diff --git a/main.go b/main.go index 03e8f0e..168b4c3 100644 --- a/main.go +++ b/main.go @@ -5,8 +5,8 @@ import ( "encoding/json" "flag" "fmt" - "github.com/gorilla/sessions" "github.com/gorilla/csrf" + "github.com/gorilla/sessions" _ "github.com/lib/pq" "net/http" "os" @@ -45,7 +45,7 @@ func loadConfig(env string) { fmt.Println("Error: cannot open config file: config/" + env + ".json") os.Exit(-1) } - defer file.Close(); + defer file.Close() decoder := json.NewDecoder(file) err = decoder.Decode(&config) if err != nil { @@ -75,7 +75,7 @@ func csrfSecret() string { fmt.Println("Error: cannot open csrf-secret.txt (run gen-csrf.sh to generate)") os.Exit(-1) } - defer file.Close(); + defer file.Close() var bytes []byte = make([]byte, 32) n, err := file.Read(bytes) if err != nil || n != 32 { diff --git a/route_handlers.go b/route_handlers.go index 9dc4ded..f6ee6ac 100644 --- a/route_handlers.go +++ b/route_handlers.go @@ -69,7 +69,7 @@ func LoginFormHandler(w http.ResponseWriter, r *http.Request) { flashes := GetFlashes(session) session.Save(r, w) - ShowTemplate("login", w, map[string]interface{}{"flashes": flashes}) + ShowTemplate("login", w, r, map[string]interface{}{"flashes": flashes}) } // handler for login POST @@ -146,7 +146,7 @@ func addFormHandler(w http.ResponseWriter, r *http.Request, user *user.User, ses popup := r.URL.Query().Get("popup") - ShowTemplate("post", w, map[string]interface{}{"mode": "add", "user": user, "flashes": flashes, "link": url, "categories": categories.CategoriesTree, "title": title, "popup": popup, "category_id": -1}) + ShowTemplate("post", w, r, map[string]interface{}{"mode": "add", "user": user, "flashes": flashes, "link": url, "categories": categories.CategoriesTree, "title": title, "popup": popup, "category_id": -1}) } func addPostHandler(w http.ResponseWriter, r *http.Request, user *user.User, session *sessions.Session) { @@ -159,8 +159,8 @@ func addPostHandler(w http.ResponseWriter, r *http.Request, user *user.User, ses category_id, err := strconv.Atoi(r.FormValue("category")) if err != nil { var flashes = make(map[string]interface{}) - flashes["error"] = []string{ "Category required: " +err.Error() } - ShowTemplate("post", w, map[string]interface{}{"mode": "add", "user": user, "flashes": flashes, "link": news.Url, "categories": categories.CategoriesTree, "title": news.Title, "popup": popup, "notes": news.Notes, "category_id": news.Category_id}) + flashes["error"] = []string{"Category required: " + err.Error()} + ShowTemplate("post", w, r, map[string]interface{}{"mode": "add", "user": user, "flashes": flashes, "link": news.Url, "categories": categories.CategoriesTree, "title": news.Title, "popup": popup, "notes": news.Notes, "category_id": news.Category_id}) return } news.Category_id = category_id @@ -168,8 +168,8 @@ func addPostHandler(w http.ResponseWriter, r *http.Request, user *user.User, ses err = news.Insert(db) if err != nil { var flashes = make(map[string]interface{}) - flashes["error"] = []string{ "Error saving news: "+err.Error() } - ShowTemplate("post", w, map[string]interface{}{"mode": "add", "user": user, "flashes": flashes, "link": news.Url, "categories": categories.CategoriesTree, "title": news.Title, "popup": popup, "notes": news.Notes, "category_id": news.Category_id}) + flashes["error"] = []string{"Error saving news: " + err.Error()} + ShowTemplate("post", w, r, map[string]interface{}{"mode": "add", "user": user, "flashes": flashes, "link": news.Url, "categories": categories.CategoriesTree, "title": news.Title, "popup": popup, "notes": news.Notes, "category_id": news.Category_id}) return } else { session.AddFlash("Added news \""+news.Title+"\"", flash_info) @@ -197,14 +197,14 @@ func editFormHandler(w http.ResponseWriter, r *http.Request, user *user.User, se newsItem, err := news.Get(db, id) if err != nil { - session.AddFlash("Could not load news item " + strconv.Itoa(id), flash_err) + session.AddFlash("Could not load news item "+strconv.Itoa(id), flash_err) session.Save(r, w) http.Redirect(w, r, "/news", http.StatusFound) return } session.Save(r, w) - ShowTemplate("post", w, map[string]interface{}{"mode": "edit", "user": user, "flashes": flashes, "categories": categories.CategoriesTree, "link": newsItem.Url, "title": newsItem.Title, "notes": newsItem.Notes, "popup": false, "category_id": newsItem.Category_id, "id": newsItem.Id()}) + ShowTemplate("post", w, r, map[string]interface{}{"mode": "edit", "user": user, "flashes": flashes, "categories": categories.CategoriesTree, "link": newsItem.Url, "title": newsItem.Title, "notes": newsItem.Notes, "popup": false, "category_id": newsItem.Category_id, "id": newsItem.Id()}) return } @@ -233,8 +233,8 @@ func editPostHandler(w http.ResponseWriter, r *http.Request, user *user.User, se category_id, err := strconv.Atoi(r.FormValue("category")) if err != nil { var flashes = make(map[string]interface{}) - flashes["error"] = []string{ "Category required: " +err.Error() } - ShowTemplate("post", w, map[string]interface{}{"mode": "edit", "user": user, "flashes": flashes, "link": news.Url, "categories": categories.CategoriesTree, "title": news.Title, "popup": false, "notes": news.Notes, "category_id": news.Category_id, "id": news.Id()}) + flashes["error"] = []string{"Category required: " + err.Error()} + ShowTemplate("post", w, r, map[string]interface{}{"mode": "edit", "user": user, "flashes": flashes, "link": news.Url, "categories": categories.CategoriesTree, "title": news.Title, "popup": false, "notes": news.Notes, "category_id": news.Category_id, "id": news.Id()}) return } news.Category_id = category_id @@ -242,8 +242,8 @@ func editPostHandler(w http.ResponseWriter, r *http.Request, user *user.User, se err = news.Update(db) if err != nil { var flashes = make(map[string]interface{}) - flashes["error"] = []string{ "Error saving news: "+err.Error() } - ShowTemplate("post", w, map[string]interface{}{"mode": "edit", "user": user, "flashes": flashes, "link": news.Url, "categories": categories.CategoriesTree, "title": news.Title, "popup": false, "notes": news.Notes, "category_id": news.Category_id, "id": news.Id()}) + flashes["error"] = []string{"Error saving news: " + err.Error()} + ShowTemplate("post", w, r, map[string]interface{}{"mode": "edit", "user": user, "flashes": flashes, "link": news.Url, "categories": categories.CategoriesTree, "title": news.Title, "popup": false, "notes": news.Notes, "category_id": news.Category_id, "id": news.Id()}) return } else { session.AddFlash("Updated news \""+news.Title+"\"", flash_info) @@ -252,7 +252,6 @@ func editPostHandler(w http.ResponseWriter, r *http.Request, user *user.User, se } } - func templateFormHandler(w http.ResponseWriter, r *http.Request, user *user.User, session *sessions.Session) { flashes := GetFlashes(session) session.Save(r, w) @@ -272,7 +271,7 @@ func templateFormHandler(w http.ResponseWriter, r *http.Request, user *user.User fmt.Println("Exec err: ", err) } - ShowTemplate("export", w, map[string]interface{}{"user": user, "flashes": flashes, "template": &templateBuf, "count": count, "url": config.Url}) + ShowTemplate("export", w, r, map[string]interface{}{"user": user, "flashes": flashes, "template": &templateBuf, "count": count, "url": config.Url}) } func exportHandler(w http.ResponseWriter, r *http.Request, user *user.User, session *sessions.Session) { @@ -290,7 +289,7 @@ func exportHandler(w http.ResponseWriter, r *http.Request, user *user.User, sess func addedHandler(w http.ResponseWriter, r *http.Request, user *user.User, session *sessions.Session) { flashes := GetFlashes(session) session.Save(r, w) - ShowTemplate("added", w, map[string]interface{}{"user": user, "flashes": flashes}) + ShowTemplate("added", w, r, map[string]interface{}{"user": user, "flashes": flashes}) } func deleteHandler(w http.ResponseWriter, r *http.Request, user *user.User, session *sessions.Session) { @@ -316,7 +315,7 @@ func categoriesFormHandler(w http.ResponseWriter, r *http.Request, user *user.Us session.Save(r, w) categories.LoadCategories(db) - ShowTemplate("categories", w, map[string]interface{}{"user": user, "flashes": flashes, "categories": categories.CategoriesTree}) + ShowTemplate("categories", w, r, map[string]interface{}{"user": user, "flashes": flashes, "categories": categories.CategoriesTree}) } func categoriesPostHandler(w http.ResponseWriter, r *http.Request, user *user.User, session *sessions.Session) { @@ -411,7 +410,7 @@ func newsFormHandler(w http.ResponseWriter, r *http.Request, user *user.User, se session.AddFlash("Error loading news", flash_err) } - ShowTemplate("news", w, map[string]interface{}{"user": user, "flashes": flashes, "news": news, "count": count, "categories": categories.CategoriesFlat}) + ShowTemplate("news", w, r, map[string]interface{}{"user": user, "flashes": flashes, "news": news, "count": count, "categories": categories.CategoriesFlat}) } @@ -421,14 +420,13 @@ func ServeFileHandler(res http.ResponseWriter, req *http.Request) { } func init_route_handlers() *mux.Router { - http.Handle("/js/", http.StripPrefix("/js/", http.FileServer(http.Dir("js/")))) - http.Handle("/css/", http.StripPrefix("/css/", http.FileServer(http.Dir("css/")))) - http.Handle("/fonts/", http.StripPrefix("/fonts", http.FileServer(http.Dir("fonts/")))) - http.HandleFunc("/favicon.ico", ServeFileHandler) - r := mux.NewRouter() - // TODO: CSRF + r.Handle("/js/", http.StripPrefix("/js/", http.FileServer(http.Dir("js/")))) + r.Handle("/css/", http.StripPrefix("/css/", http.FileServer(http.Dir("css/")))) + r.Handle("/fonts/", http.StripPrefix("/fonts", http.FileServer(http.Dir("fonts/")))) + r.HandleFunc("/favicon.ico", ServeFileHandler) + r.HandleFunc("/login", getPostHandler(LoginFormHandler, LoginPostHandler)) r.HandleFunc("/logout", userHandler(LogoutHandler)) @@ -447,6 +445,6 @@ func init_route_handlers() *mux.Router { r.HandleFunc("/categories/delete", userHandler(categoryDeleteHandler)) http.Handle("/", r) - + return r } diff --git a/templates.go b/templates.go index 9112710..af32df6 100644 --- a/templates.go +++ b/templates.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "github.com/dballard/transmet/categories" + "github.com/gorilla/csrf" + "html/template" "net/http" "path/filepath" @@ -86,7 +88,8 @@ func initTemplates() { } } -func ShowTemplate(template string, w http.ResponseWriter, data map[string]interface{}) { +func ShowTemplate(template string, w http.ResponseWriter, r *http.Request, data map[string]interface{}) { + data[csrf.TemplateTag] = csrf.TemplateField(r) err := templates[template].ExecuteTemplate(w, "layout.html", data) if err != nil { fmt.Println("Exec err: ", err) diff --git a/templates/pages/categories.html b/templates/pages/categories.html index 729203d..8e08862 100644 --- a/templates/pages/categories.html +++ b/templates/pages/categories.html @@ -14,6 +14,7 @@ {{template "select-category" dict "categories" .categories "id" -1}}
+ {{ .csrfField }}
@@ -37,6 +38,7 @@ new category select
+ {{ .csrfField }} {{if $.category.Parent.Valid }} {{template "select-category" dict "categories" .categories "id" $.category.Parent.Value}} {{else}} diff --git a/templates/pages/login.html b/templates/pages/login.html index 42b4d95..6f3014c 100644 --- a/templates/pages/login.html +++ b/templates/pages/login.html @@ -4,6 +4,7 @@ {{template "flashes" .}} +{{ .csrfField }}
{{end}} diff --git a/templates/pages/post.html b/templates/pages/post.html index d036d78..e88971f 100644 --- a/templates/pages/post.html +++ b/templates/pages/post.html @@ -16,7 +16,7 @@
Notes:
-
+
{{ .csrfField }}