From 254bf100549f70a75ec312cfd2b3421c95ea8b4e Mon Sep 17 00:00:00 2001 From: Jan Barfuss Date: Fri, 20 Dec 2024 15:19:33 +0100 Subject: [PATCH] User auth added --- config/config.go | 15 +++- database/database.go | 12 +++- go.mod | 1 + go.sum | 2 + server/auth.go | 164 +++++++++++++++++++++++++++++++++++++++++++ server/server.go | 41 +++++++---- 6 files changed, 219 insertions(+), 16 deletions(-) create mode 100644 server/auth.go diff --git a/config/config.go b/config/config.go index 83197ae..cbe0b34 100644 --- a/config/config.go +++ b/config/config.go @@ -10,7 +10,7 @@ import ( func Load() map[string]string { var env map[string]string = make(map[string]string) - validEnv := []string{"DB_USER", "DB_PASSWORD", "DB_NAME", "DB_HOST", "DB_PORT", "PORT"} + validEnv := []string{"DB_USER", "DB_PASSWORD", "DB_NAME", "DB_HOST", "DB_PORT", "PORT", "SECRET"} envpath := "./.env" @@ -39,9 +39,22 @@ func Load() map[string]string { } checkDB(env) + checkSecret(env) return env } +func checkSecret(env map[string]string) { + if _, ok := env["SECRET"]; !ok { + fmt.Println("SECRET is not set") + os.Exit(1) + } + + if len(env["SECRET"]) < 32 { + fmt.Println("SECRET is too short") + os.Exit(1) + } +} + func checkDB(env map[string]string) { required := []string{"DB_USER", "DB_PASSWORD", "DB_HOST", "DB_PORT"} optional := []string{"DB_NAME"} diff --git a/database/database.go b/database/database.go index 511f75c..47124ce 100644 --- a/database/database.go +++ b/database/database.go @@ -187,7 +187,7 @@ func GetTransactions(database *sql.DB, account int, year int, month int) ([]Tran return nil, errors.New("year is required") } - // Extract is probably not used right + // TODO: Extract is probably not used right if month == 0 { row, err = database.Query("SELECT * FROM transactions WHERE account = $1 AND EXTRACT(YEAR FROM time) = $2", account, year) } else { @@ -248,6 +248,16 @@ func GetUser(database *sql.DB, id int) (User, error) { return user, nil } +func GetUserByName(database *sql.DB, name string) (User, error) { + var user User + err := database.QueryRow("SELECT * FROM users WHERE name = $1", name).Scan(&user.ID, &user.Name, &user.Password) + if err != nil { + return user, err + } + return user, nil + +} + func AuthenicateUser(database *sql.DB, name string, password string) (User, error) { var user User err := database.QueryRow("SELECT * FROM users WHERE name = $1 AND password = $2", name, password).Scan(&user.ID, &user.Name, &user.Password) diff --git a/go.mod b/go.mod index a225dbe..4ff98ec 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.23.0 // indirect github.com/goccy/go-json v0.10.4 // indirect + github.com/golang-jwt/jwt/v4 v4.5.1 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/go.sum b/go.sum index 1c806a6..78749da 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,8 @@ github.com/go-playground/validator/v10 v10.23.0 h1:/PwmTwZhS0dPkav3cdK9kV1FsAmrL github.com/go-playground/validator/v10 v10.23.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/goccy/go-json v0.10.4 h1:JSwxQzIqKfmFX1swYPpUThQZp/Ka4wzJdK0LWVytLPM= github.com/goccy/go-json v0.10.4/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= +github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= diff --git a/server/auth.go b/server/auth.go new file mode 100644 index 0000000..761a907 --- /dev/null +++ b/server/auth.go @@ -0,0 +1,164 @@ +package server + +import ( + "fmt" + "net/http" + "strings" + "time" + + "github.com/LeRoid-hub/Bookholder-API/database" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v4" + "golang.org/x/crypto/bcrypt" +) + +type AuthInput struct { + Username string `json:"username"` + Password string `json:"password"` +} + +func createUser(c *gin.Context) { + var authInput AuthInput + + if err := c.ShouldBindJSON(&authInput); err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + + var user database.User + user, err := database.GetUserByName(Database, authInput.Username) + if err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + + if user.ID != 0 { + c.JSON(400, gin.H{"error": "User already exists"}) + return + } + + passwordHash, err := bcrypt.GenerateFromPassword([]byte(authInput.Password), bcrypt.DefaultCost) + if err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + + user = database.User{ + Name: authInput.Username, + Password: string(passwordHash), + } + + err = database.NewUser(Database, user) + if err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + + c.JSON(200, gin.H{"message": "User created"}) +} + +func authenticateUser(c *gin.Context) { + var authInput AuthInput + + if err := c.ShouldBindJSON(&authInput); err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + + var userFound database.User + userFound, err := database.GetUserByName(Database, authInput.Username) + if err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + + if userFound.ID == 0 { + c.JSON(400, gin.H{"error": "User not found"}) + return + } + + err = bcrypt.CompareHashAndPassword([]byte(userFound.Password), []byte(authInput.Password)) + if err != nil { + c.JSON(400, gin.H{"error": "Invalid password"}) + return + } + + generateToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "id": userFound.ID, + "exp": time.Now().Add(time.Hour * 24).Unix(), + }) + + tokenString, err := generateToken.SignedString([]byte(Env["SECRET"])) + if err != nil { + c.JSON(400, gin.H{"error": "failed to generate token"}) + return + } + + c.JSON(200, gin.H{"token": tokenString}) +} + +func getUserProfile(c *gin.Context) { + user, _ := c.Get("currentUser") + c.JSON(200, gin.H{"user": user}) +} + +func checkAuth(c *gin.Context) { + authHeader := c.GetHeader("Authorization") + + if authHeader == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"}) + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + authToken := strings.Split(authHeader, " ") + if len(authToken) != 2 || authToken[0] != "Bearer" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format"}) + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + tokenString := authToken[1] + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(Env["SECRET"]), nil + }) + + if err != nil || token.Valid { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"}) + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"}) + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + if float64(time.Now().Unix()) > claims["exp"].(float64) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Token expired"}) + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + user, err := database.GetUser(Database, int(claims["id"].(float64))) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"}) + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + if user.ID == 0 { + c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"}) + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + c.Set("currentUser", user) + + c.Next() +} diff --git a/server/server.go b/server/server.go index d8e1279..f36dcc4 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "database/sql" "net/http" "github.com/LeRoid-hub/Bookholder-API/database" @@ -26,30 +27,42 @@ import ( */ +var ( + Database *sql.DB + Env map[string]string +) + func Run(env map[string]string, db *database.DB) { + dbase, err := database.New() + if err != nil { + panic(err) + } + Database = dbase + r := gin.Default() v1 := r.Group("/v1") { //Account - v1.GET("/Account/:AccountID", getAccount) - v1.POST("/NewAccount", newAccount) - v1.PUT("/UpdateAccount/:AccountID", updateAccount) - v1.DELETE("/DeleteAccount/:AccountID", deleteAccount) + v1.GET("/Account/:AccountID", checkAuth, getAccount) + v1.POST("/NewAccount", checkAuth, newAccount) + v1.PUT("/UpdateAccount/:AccountID", checkAuth, updateAccount) + v1.DELETE("/DeleteAccount/:AccountID", checkAuth, deleteAccount) //Transaction - v1.GET("/Transaction/:TransactionID", getTransaction) - v1.GET("/Transactions/:AccountID/:year", getTransactions) - v1.GET("/Transactions/:AccountID/:year/:month", getTransactions) - v1.POST("/NewTransaction", newTransaction) - v1.PUT("/UpdateTransaction/:TransactionID", updateTransaction) - v1.DELETE("/DeleteTransaction/:TransactionID", deleteTransaction) + v1.GET("/Transaction/:TransactionID", checkAuth, getTransaction) + v1.GET("/Transactions/:AccountID/:year", checkAuth, getTransactions) + v1.GET("/Transactions/:AccountID/:year/:month", checkAuth, getTransactions) + v1.POST("/NewTransaction", checkAuth, newTransaction) + v1.PUT("/UpdateTransaction/:TransactionID", checkAuth, updateTransaction) + v1.DELETE("/DeleteTransaction/:TransactionID", checkAuth, deleteTransaction) //User - v1.GET("/User/:UserID", getUser) - v1.POST("/NewUser", newUser) - v1.PUT("/UpdateUser/:UserID", updateUser) - v1.DELETE("/DeleteUser/:UserID", deleteUser) + v1.GET("/User/", checkAuth, getUserProfile) + v1.POST("/NewUser", createUser, newUser) + v1.POST("/AuthenticateUser", authenticateUser) + v1.PUT("/UpdateUser/:UserID", checkAuth, updateUser) + v1.DELETE("/DeleteUser/:UserID", checkAuth, deleteUser) } r.GET("/ping", func(c *gin.Context) {