diff --git a/config/local.json b/config/local.json index f96cdf8..9458efe 100644 --- a/config/local.json +++ b/config/local.json @@ -4,6 +4,6 @@ "Dbname": "transmet", "Username": "transmet", "Password": "asdfasdf" - } - "Port": 8001 + }, + "Port": "8001" } diff --git a/main.go b/main.go index 5eb8821..0731c8b 100644 --- a/main.go +++ b/main.go @@ -44,11 +44,16 @@ func loadConfig() { os.Exit(-1) } decoder := json.NewDecoder(file) - decoder.Decode(&config) + err = decoder.Decode(&config) + if err != nil { + fmt.Println("Error: cannot decode config file: ", err) + os.Exit(-1) + } } func dbConnect() { var err error + fmt.Println(fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=require", config.Sql.Username, config.Sql.Password, config.Sql.Host, config.Sql.Dbname)) db, err = sql.Open("postgres", fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=require", config.Sql.Username, config.Sql.Password, config.Sql.Host, config.Sql.Dbname)) if err != nil { diff --git a/route_handlers.go b/route_handlers.go index f4b22b1..d442639 100644 --- a/route_handlers.go +++ b/route_handlers.go @@ -2,10 +2,98 @@ package main import ( "github.com/gorilla/mux" + "github.com/gorilla/sessions" "net/http" - + "github.com/dballard/transmet/user" + "fmt" + "time" ) +func GetFlashes(session *sessions.Session) map[string]interface{} { + var flashes = make(map[string]interface{}) + flashes["error"] = session.Flashes(flash_err) + flashes["info"] = session.Flashes(flash_info) + return flashes +} + +func sessionWipe(session *sessions.Session) { + session.Values = make(map[interface{}]interface{}) +} + +func initSessionUser(r *http.Request) (*user.User, *sessions.Session) { + session, _ := store.Get(r, "c_user") + if session.Values["username"] == nil { + return nil, session + } + + return user.NewUserFromUsername(db, session.Values["username"].(string)), session +} + +// wrapper for handlers requiring a User +func userHandler(next func(http.ResponseWriter, *http.Request, *user.User)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + user, _ := initSessionUser(r) + if user == nil { + http.Redirect(w, r, "/", http.StatusFound) + } else { + next(w, r, user) + } + } +} + +// wrapper for handlers forking on GET and POST +// r.HandleFunc("/login", getPostHandler(LoginFormHandler, LoginPostHandler)) +func getPostHandler(getFn, postFn func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + getFn(w, r) + } else { // POST + postFn(w, r) + } + } +} + +// Log in page handler +func LoginFormHandler(w http.ResponseWriter, r *http.Request) { + session, _ := store.Get(r, "c_user") + flashes := GetFlashes(session) + session.Save(r, w) + err := templates["login"].Execute(w, map[string]interface{}{"flashes": flashes}) + if err != nil { + fmt.Println("Exec err: ", err) + } +} + +// handler for login POST +// TODO: proper per account and client flood control rate limiting +// currently weak per call slow down is by-passable at scale +func LoginPostHandler(w http.ResponseWriter, r *http.Request) { + time.Sleep(500 * time.Millisecond) // WEAK poor mans rate limiting for logins + r.ParseForm() + username := r.PostFormValue("username") + // lookup user + password := r.PostFormValue("password") + user := user.NewUserFromAuth(db, username, password) + if user != nil { + session, _ := store.Get(r, "c_user") + session.Values["username"] = user.Username + session.Save(r, w) + http.Redirect(w, r, "/home", http.StatusFound) + } else { + time.Sleep(500 * time.Millisecond) // WEAK bypassable poor mans rate limiting for failed logins + session, _ := store.Get(r, "c_user") + session.AddFlash("Username or password", flash_err) + session.Save(r, w) + http.Redirect(w, r, "/login", http.StatusFound) + } +} + + +func addFormHandler( ) { + +} + + func init_route_handlers() { http.Handle("/js/", http.StripPrefix("/js/", http.FileServer(http.Dir("js/")))) http.Handle("/css/", http.StripPrefix("/css/", http.FileServer(http.Dir("css/")))) @@ -13,6 +101,9 @@ func init_route_handlers() { r := mux.NewRouter() + r.HandleFunc("/login", getPostHandler(LoginFormHandler, LoginPostHandler)) + r.HandleFunc("/add", getPostHandler(userHandler(addFormHandler), userHandler(addPostHandler))) + r.HandleFunc("/", getPostHandler(userHandler(templateFormHandler), userHandler(templatePostHandler))) http.Handle("/", r) diff --git a/user/user.go b/user/user.go new file mode 100644 index 0000000..fe024c8 --- /dev/null +++ b/user/user.go @@ -0,0 +1,92 @@ +package user + +import ( + "code.google.com/p/go.crypto/bcrypt" + "crypto/rand" + "database/sql" + "fmt" + _ "github.com/lib/pq" +) + +func clear(b []byte) { + for i := 0; i < len(b); i++ { + b[i] = 0 + } +} + +func Crypt(password []byte) ([]byte, error) { + defer clear(password) + return bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost) +} + +type User struct { + Username string + db *sql.DB +} + +func UsernameExists(db *sql.DB, username string) (bool, error) { + rows, err := db.Query("SELECT count(username) FROM users where username=$1", username) + if err != nil { + fmt.Println("User DB Error: ", err) + return false, err + } + var count int + rows.Next() + rows.Scan(&count) + return count > 0, nil +} + +func GenDisposablePassword() string { + b := make([]byte, 16) + _, err := rand.Read(b) + if err != nil { + fmt.Println("user.GenDisposablePassword() error reading from urandom: ", err) + } + return fmt.Sprintf("%x", b) +} + +func NewUserFromAuth(db *sql.DB, username, password string) *User { + fmt.Println("NewUserFromAuth:", username, ":", password) + rows, err := db.Query("SELECT password FROM users WHERE username = $1", username) + if err != nil { + fmt.Println("User DB Error: ", err) + return nil + } + var hash_db string + user := User{db: db} + + if rows.Next() { + var pw sql.NullString + err := rows.Scan(&user.Username, &pw) + if err != nil { + fmt.Println("scan err: ", err) + } + hash_db = pw.String + } else { + return nil + } + if err = bcrypt.CompareHashAndPassword([]byte(hash_db), []byte(password)); err == nil { + return &user + } + fmt.Println("auth fail:", err) + return nil +} + +func NewUserFromUsername(db *sql.DB, username string) *User { + rows, err := db.Query("SELECT username FROM users WHERE username=$1", username) + if err != nil { + fmt.Println("User DB Error: ", err) + return nil + } + user := User{db: db} + if rows.Next() { + err = rows.Scan(&user.Username) + if err != nil { + fmt.Println("Scan err: ", err) + } + } else { + fmt.Println("User DB Error: No user found with username ", username) + return nil + } + return &user +}