diff --git a/constants.go b/constants.go index aa9b41e..a8c54f0 100644 --- a/constants.go +++ b/constants.go @@ -4,4 +4,8 @@ const ( // Header constants headerUUID = "uuid" headerTimestamp = "time" + + // Data packet types + kindData = "data" + kindHeartbeat = "heartbeat" ) diff --git a/data-types.go b/data-types.go index 580dff5..db5e1c2 100644 --- a/data-types.go +++ b/data-types.go @@ -1,6 +1,52 @@ package main +import "sync" + type StampedReading struct { - Timestamp int - Value interface{} + Timestamp int `json:"timestamp"` + Value interface{} `json:"value"` +} + +type Packet struct { + ID int `json:"id"` + Kind string `json:"kind"` + Message []byte `json:"message"` +} + +type generator struct { + databaseId int + serverSn string + socket *socketConn +} + +type generatorMap struct { + lock sync.Mutex + mapping map[string]*generator +} + +func newGenMap() *generatorMap { + return &generatorMap{sync.Mutex{}, make(map[string]*generator)} } + +func (g *generatorMap) add(k string, v *generator) { + g.lock.Lock() + defer g.lock.Unlock() + g.mapping[k] = v +} + +type IDMap struct { + lock sync.Mutex + mapping map[int]string +} + +func newIDMap() *IDMap { + return &IDMap{sync.Mutex{}, make(map[int]string)} +} + +func (m *IDMap) add(k int, v string) { + m.lock.Lock() + defer m.lock.Unlock() + m.mapping[k] = v +} + +// TODO: maybe break this up a bit diff --git a/gateway.go b/gateway.go index 409faa4..1694a6b 100644 --- a/gateway.go +++ b/gateway.go @@ -1,18 +1,35 @@ package main import ( + "database/sql" "encoding/json" "fmt" + _ "github.com/go-sql-driver/mysql" + "github.com/gorilla/websocket" "io/ioutil" "net/http" + "os" + "strings" + "sync" ) // TODO: log to file +var genMap = newGenMap() +var DBIDMap = newIDMap() +var dbConn *sql.DB +var upgrader = websocket.Upgrader{} + func main() { - http.HandleFunc("/data", dataHandler) - http.HandleFunc("/heartbeat", heartbeatHandler) + dbConnect() // Connect to the database + defer func() { // TODO: handle SIGINT + if err := dbConn.Close(); err != nil { + fmt.Printf("Error closing db connection: %v\n", err) + } + }() + + http.HandleFunc("/ws", connectionHandler) // All incoming websocket connections use this if err := http.ListenAndServeTLS( ":48820", @@ -23,23 +40,101 @@ func main() { } } -func dataHandler(w http.ResponseWriter, r *http.Request) { - fmt.Printf("Got data from: %v at %v\n", r.Header.Get("uuid"), r.Header.Get("time")) - body, _ := ioutil.ReadAll(r.Body) - var data map[string][]StampedReading - if err := json.Unmarshal(body, &data); err != nil { - fmt.Printf("Error unmarshalling body: %v\n", err) +func dbConnect() { + sqlConf, err := ioutil.ReadFile("/root/.mysql_goconf") // Read the db config file + if err != nil { + fmt.Printf("Error reading MySQL config file: %v\n", err) + os.Exit(1) // If we can't connect to the database there's not much point in continuing + } + + connInfo := strings.Fields(string(sqlConf)) + + dbConn, err = sql.Open("mysql", fmt.Sprintf("%v:%v@tcp(%v)/%v", + connInfo[0], + connInfo[1], + connInfo[2], + connInfo[3])) + + if err != nil { + fmt.Printf("Error connecting to the database: %v\n", err) + // TODO: maybe not die here and probably handle the database going down and coming back up + os.Exit(1) + } +} + +type socketConn struct { + // Struct to handle our websocket connection + conn *websocket.Conn + lock sync.Mutex + genId string +} + +func connectionHandler(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) // Upgrade to websocket + if err != nil { + fmt.Printf("Error upgrading websocket: %v\n", err) + return } else { - fmt.Printf("Data received: %v\n", data) + fmt.Printf("Got connection from %v@%v\n", r.Header.Get(headerUUID), conn.RemoteAddr()) + } + ws := &socketConn{conn: conn, lock: sync.Mutex{}} + + g := &generator{socket: ws} + row := dbConn.QueryRow( // Figure out the database id of this generator + "SELECT gen_id, server_id FROM generators WHERE server_id = ?", + r.Header.Get(headerUUID)) + if err := row.Scan(&g.databaseId, &g.serverSn); err != nil { // Dump that info into our generator struct + fmt.Printf("Error with database results: %v\n", err) + } else { + fmt.Printf("Got generator: %v\n", g) + } + ws.genId = g.serverSn + genMap.add(g.serverSn, g) // Add to map of SNs to generator IDs/ws connections + DBIDMap.add(g.databaseId, g.serverSn) // Add to map of database IDs to SNs + + for { // Do forever + var packet Packet + err := conn.ReadJSON(&packet) // Get our next packet + if err != nil { + fmt.Printf("Error reading message: %v\n", err) + return // End this connection + } else { + switch packet.Kind { + case kindData: + go ws.handleData(packet) + case kindHeartbeat: + go ws.handleHeartbeat(packet) + default: + fmt.Printf("Received unknown packet type: %v\n", packet.Kind) + } + } } - w.WriteHeader(200) } -func heartbeatHandler(w http.ResponseWriter, r *http.Request) { - fmt.Printf("Got heartbeat from %v@%v at %v\n", - r.Header.Get(headerUUID), - r.RemoteAddr, - r.Header.Get(headerTimestamp)) +func (ws *socketConn) handleData(data Packet) { + var msg map[string][]StampedReading // The message should always be a byte slice that will unmarshal into this + if err := json.Unmarshal(data.Message, &msg); err != nil { + fmt.Printf("Error unmarshalling data: %v\n", err) + } else { + fmt.Printf("got some data: %v\n", msg) + // TODO: Better response handling + ws.lock.Lock() + if err := ws.conn.WriteJSON(Packet{data.ID, "response", []byte("received")}); err != nil { + fmt.Printf("Error sending response: %v\n", err) + } else { + fmt.Printf("Sent response for packet %v\n", data.ID) + } + ws.lock.Unlock() + } +} - w.WriteHeader(200) +func (ws *socketConn) handleHeartbeat(data Packet) { + fmt.Printf("Heartbeat from %v\n", ws.conn.RemoteAddr()) + ws.lock.Lock() + if err := ws.conn.WriteJSON(Packet{data.ID, "response", []byte("received")}); err != nil { + fmt.Printf("Error sending response: %v\n", err) + } else { + fmt.Printf("Sent response for packet %v\n", data.ID) + } + ws.lock.Unlock() }