diff --git a/main.go b/main.go index d8e5100..03e8f0e 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "flag" "fmt" "github.com/gorilla/sessions" + "github.com/gorilla/csrf" _ "github.com/lib/pq" "net/http" "os" @@ -44,6 +45,7 @@ func loadConfig(env string) { fmt.Println("Error: cannot open config file: config/" + env + ".json") os.Exit(-1) } + defer file.Close(); decoder := json.NewDecoder(file) err = decoder.Decode(&config) if err != nil { @@ -67,6 +69,21 @@ func dbConnect() { } } +func csrfSecret() string { + file, err := os.Open("csrf-secret.txt") + if err != nil { + fmt.Println("Error: cannot open csrf-secret.txt (run gen-csrf.sh to generate)") + os.Exit(-1) + } + defer file.Close(); + var bytes []byte = make([]byte, 32) + n, err := file.Read(bytes) + if err != nil || n != 32 { + fmt.Println("Error: cannot open csrf-secret.txt (run gen-csrf.sh to generate)") + } + return string(bytes) +} + func main() { fmt.Println("transmet ", VERSION) @@ -78,10 +95,10 @@ func main() { loadConfig(*envFlag) dbConnect() initTemplates() - init_route_handlers() + r := init_route_handlers() fmt.Println("Listening on", config.Port, "...") - err := http.ListenAndServe(":"+config.Port, nil) + err := http.ListenAndServe(":"+config.Port, csrf.Protect([]byte(csrfSecret()))(r)) if err != nil { fmt.Println("Fatal Error: ", err) } diff --git a/route_handlers.go b/route_handlers.go index 0adaabf..9dc4ded 100644 --- a/route_handlers.go +++ b/route_handlers.go @@ -420,7 +420,7 @@ func ServeFileHandler(res http.ResponseWriter, req *http.Request) { http.ServeFile(res, req, "./"+fname) } -func init_route_handlers() { +func init_route_handlers() *mux.Router { http.Handle("/js/", http.StripPrefix("/js/", http.FileServer(http.Dir("js/")))) http.Handle("/css/", http.StripPrefix("/css/", http.FileServer(http.Dir("css/")))) http.Handle("/fonts/", http.StripPrefix("/fonts", http.FileServer(http.Dir("fonts/")))) @@ -447,4 +447,6 @@ func init_route_handlers() { r.HandleFunc("/categories/delete", userHandler(categoryDeleteHandler)) http.Handle("/", r) + + return r }