Jelajahi Sumber

complete rewrite

teizz 3 tahun lalu
induk
melakukan
84945beb02
1 mengubah file dengan 111 tambahan dan 102 penghapusan
  1. 111 102
      main.go

+ 111 - 102
main.go

@@ -1,150 +1,159 @@
 package main
 
 import (
-	"fmt"
 	"io"
 	"log"
 	"net/http"
-	"strconv"
 	"sync"
+	"sync/atomic"
 )
 
 var (
+	// variables to set during build-time
+	debugging = ""
 	version   = "0.0-undefined"
 	buildtime = "0000-00-00T00:00:00+0000"
-	queues    = make(map[string]*Queue)
-	mut       = &sync.Mutex{}
+
+	// actual business end of the device
+	paths = &sync.Map{}
 )
 
+// Transfer holds a single tranferable connection to be read
 type Transfer struct {
-	reader io.Reader
-	size   int
+	reader        *io.PipeReader
+	done          chan struct{}
+	contentlength string
 }
 
+// Queue is where posts and gets can exchange transfers
 type Queue struct {
-	ch        chan Transfer
-	producers int
-	consumers int
-}
-
-func NewQueue() *Queue {
-	return &Queue{
-		ch: make(chan Transfer),
-	}
-}
-
-func (q *Queue) addConsumer() {
-	q.consumers = q.consumers + 1
+	ch    chan Transfer
+	posts int32
+	gets  int32
 }
 
-func (q *Queue) remConsumer() {
-	q.consumers = q.consumers - 1
-}
-
-func (q *Queue) addProducer() {
-	q.producers = q.producers + 1
-}
+func pathHandler(w http.ResponseWriter, r *http.Request) {
+	pathID := r.URL.Path
 
-func (q *Queue) remProducer() {
-	q.producers = q.producers - 1
-}
-
-func (q *Queue) isEmpty() bool {
-	return q.producers == 0 && q.consumers == 0
-}
+	if len(pathID) < 2 {
+		w.WriteHeader(400)
+		w.Write([]byte("path to short"))
+		return
+	}
 
-func main() {
-	log.Printf("pathway version:%s buildtime:%s", version, buildtime)
+	if r.Method == "GET" {
+		log.Printf("%s [GET] Connected", pathID)
 
-	handle := func(w http.ResponseWriter, r *http.Request) {
-		channelId := r.URL.Path
+		queue := &Queue{ch: make(chan Transfer)}
+		if p, loaded := paths.LoadOrStore(pathID, queue); loaded {
+			queue = p.(*Queue)
+			debug("%s [GET] Loads path", pathID)
+		} else {
+			debug("%s [GET] Created path", pathID)
+		}
+		atomic.AddInt32(&queue.gets, 1)
 
-		if r.Method == "GET" {
-			log.Printf("GET: %s", channelId)
-			mut.Lock()
-			queue, exists := queues[channelId]
-			if !exists {
-				queue = NewQueue()
-				queues[channelId] = queue
+		select {
+		case transfer := <-queue.ch:
+			debug("%s [GET] Reads from path", pathID)
+			if transfer.contentlength != "" {
+				w.Header().Set("Content-Length", transfer.contentlength)
 			}
-
-			ch := queue.ch
-
-			queue.addConsumer()
-			mut.Unlock()
-
-			select {
-			case transfer := <-ch:
-				w.Header().Set("Content-Length", fmt.Sprintf("%d", transfer.size))
-				_, err := io.Copy(w, transfer.reader)
-				if err != nil {
-					if closer, ok := transfer.reader.(io.Closer); ok {
-						closer.Close()
-					}
-				}
-			case <-r.Context().Done():
+			_, err := io.Copy(w, transfer.reader)
+			if err != nil {
+				transfer.reader.Close()
 			}
+			debug("%s [GET] Sends done", pathID)
+			close(transfer.done)
+		case <-r.Context().Done():
+			debug("%s [GET] Cancels path", pathID)
+		}
 
-			mut.Lock()
-			queue.remConsumer()
-			if queue.isEmpty() {
-				delete(queues, channelId)
+		if atomic.AddInt32(&queue.gets, -1) <= 0 {
+			if atomic.LoadInt32(&queue.posts) <= 0 {
+				paths.Delete(pathID)
+				debug("%s [GET] Removes path", pathID)
 			}
-			mut.Unlock()
-		} else {
-			log.Printf("POST: %s", channelId)
-			mut.Lock()
-			queue, exists := queues[channelId]
-			if !exists {
-				queue = NewQueue()
-				queues[channelId] = queue
-			}
-
-			ch := queue.ch
+		}
+		log.Printf("%s [GET] Finishes", pathID)
 
-			queue.addProducer()
-			mut.Unlock()
+	} else {
+		log.Printf("%s [POST] Connected", pathID)
 
-			reader, writer := io.Pipe()
+		queue := &Queue{ch: make(chan Transfer)}
+		if p, loaded := paths.LoadOrStore(pathID, queue); loaded {
+			queue = p.(*Queue)
+			debug("%s [POST] Loads path", pathID)
+		} else {
+			debug("%s [POST] Creates path", pathID)
+		}
+		atomic.AddInt32(&queue.posts, 1)
 
-			contentLength, err := strconv.Atoi(r.Header.Get("Content-Length"))
-			if err != nil {
-				contentLength = 0
-			}
+		reader, writer := io.Pipe()
 
-			transfer := Transfer{
-				reader: reader,
-				size:   contentLength,
-			}
+		transfer := Transfer{
+			reader:        reader,
+			contentlength: r.Header.Get("Content-Length"),
+			done:          make(chan struct{}),
+		}
 
-			select {
-			case ch <- transfer:
-				io.Copy(writer, r.Body)
-			case <-r.Context().Done():
+		go func() {
+			n, err := io.Copy(writer, r.Body)
+			debug("%s [POST] Sends %d bytes", pathID, n)
+			if err != nil {
+				debug("%s [POST] Has error: %s", pathID, err.Error())
 			}
-
 			writer.Close()
+			r.Body.Close()
+		}()
+
+		select {
+		case queue.ch <- transfer:
+			debug("%s [POST] Writes to path", pathID)
+		case <-r.Context().Done():
+			debug("%s [POST] Cancels path", pathID)
+			close(transfer.done)
+		}
+
+		debug("%s [POST] Waits for done", pathID)
+		<-transfer.done
 
-			mut.Lock()
-			queue.remProducer()
-			if queue.isEmpty() {
-				delete(queues, channelId)
+		if atomic.AddInt32(&queue.posts, -1) <= 0 {
+			if atomic.LoadInt32(&queue.gets) <= 0 {
+				paths.Delete(pathID)
+				debug("%s [POST] Removes path", pathID)
 			}
-			mut.Unlock()
 		}
-	}
 
-	health := func(w http.ResponseWriter, r *http.Request) {
-		w.WriteHeader(http.StatusOK)
-		w.Write([]byte("OK"))
+		log.Printf("%s [POST] Finishes", pathID)
 	}
+}
+
+func okHandler(w http.ResponseWriter, r *http.Request) {
+	w.WriteHeader(http.StatusOK)
+	w.Write([]byte("ok"))
+}
 
-	http.HandleFunc("/health", health)
-	http.HandleFunc("/", handle)
+func emptyHandler(w http.ResponseWriter, r *http.Request) {
+	w.WriteHeader(http.StatusOK)
+}
+
+func main() {
+	log.Printf("pathway version:%s buildtime:%s", version, buildtime)
+
+	http.HandleFunc("/health", okHandler)
+	http.HandleFunc("/favicon.ico", emptyHandler)
+	http.HandleFunc("/robots.txt", emptyHandler)
+	http.HandleFunc("/", pathHandler)
 
 	err := http.ListenAndServe(":8080", nil)
 	if err != nil {
 		log.Println(err)
 	}
+}
 
+func debug(msg string, args ...interface{}) {
+	if len(debugging) > 0 {
+		log.Printf(msg, args...)
+	}
 }