User auth added

This commit is contained in:
Jan Barfuss 2024-12-20 15:19:33 +01:00
parent a4ff21b42c
commit 254bf10054
6 changed files with 219 additions and 16 deletions

View File

@ -10,7 +10,7 @@ import (
func Load() map[string]string { func Load() map[string]string {
var env map[string]string = make(map[string]string) var env map[string]string = make(map[string]string)
validEnv := []string{"DB_USER", "DB_PASSWORD", "DB_NAME", "DB_HOST", "DB_PORT", "PORT"} validEnv := []string{"DB_USER", "DB_PASSWORD", "DB_NAME", "DB_HOST", "DB_PORT", "PORT", "SECRET"}
envpath := "./.env" envpath := "./.env"
@ -39,9 +39,22 @@ func Load() map[string]string {
} }
checkDB(env) checkDB(env)
checkSecret(env)
return env return env
} }
func checkSecret(env map[string]string) {
if _, ok := env["SECRET"]; !ok {
fmt.Println("SECRET is not set")
os.Exit(1)
}
if len(env["SECRET"]) < 32 {
fmt.Println("SECRET is too short")
os.Exit(1)
}
}
func checkDB(env map[string]string) { func checkDB(env map[string]string) {
required := []string{"DB_USER", "DB_PASSWORD", "DB_HOST", "DB_PORT"} required := []string{"DB_USER", "DB_PASSWORD", "DB_HOST", "DB_PORT"}
optional := []string{"DB_NAME"} optional := []string{"DB_NAME"}

View File

@ -187,7 +187,7 @@ func GetTransactions(database *sql.DB, account int, year int, month int) ([]Tran
return nil, errors.New("year is required") return nil, errors.New("year is required")
} }
// Extract is probably not used right // TODO: Extract is probably not used right
if month == 0 { if month == 0 {
row, err = database.Query("SELECT * FROM transactions WHERE account = $1 AND EXTRACT(YEAR FROM time) = $2", account, year) row, err = database.Query("SELECT * FROM transactions WHERE account = $1 AND EXTRACT(YEAR FROM time) = $2", account, year)
} else { } else {
@ -248,6 +248,16 @@ func GetUser(database *sql.DB, id int) (User, error) {
return user, nil return user, nil
} }
func GetUserByName(database *sql.DB, name string) (User, error) {
var user User
err := database.QueryRow("SELECT * FROM users WHERE name = $1", name).Scan(&user.ID, &user.Name, &user.Password)
if err != nil {
return user, err
}
return user, nil
}
func AuthenicateUser(database *sql.DB, name string, password string) (User, error) { func AuthenicateUser(database *sql.DB, name string, password string) (User, error) {
var user User var user User
err := database.QueryRow("SELECT * FROM users WHERE name = $1 AND password = $2", name, password).Scan(&user.ID, &user.Name, &user.Password) err := database.QueryRow("SELECT * FROM users WHERE name = $1 AND password = $2", name, password).Scan(&user.ID, &user.Name, &user.Password)

1
go.mod
View File

@ -16,6 +16,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.23.0 // indirect github.com/go-playground/validator/v10 v10.23.0 // indirect
github.com/goccy/go-json v0.10.4 // indirect github.com/goccy/go-json v0.10.4 // indirect
github.com/golang-jwt/jwt/v4 v4.5.1 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect

2
go.sum
View File

@ -23,6 +23,8 @@ github.com/go-playground/validator/v10 v10.23.0 h1:/PwmTwZhS0dPkav3cdK9kV1FsAmrL
github.com/go-playground/validator/v10 v10.23.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-playground/validator/v10 v10.23.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.4 h1:JSwxQzIqKfmFX1swYPpUThQZp/Ka4wzJdK0LWVytLPM= github.com/goccy/go-json v0.10.4 h1:JSwxQzIqKfmFX1swYPpUThQZp/Ka4wzJdK0LWVytLPM=
github.com/goccy/go-json v0.10.4/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-json v0.10.4/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo=
github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=

164
server/auth.go Normal file
View File

@ -0,0 +1,164 @@
package server
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/LeRoid-hub/Bookholder-API/database"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v4"
"golang.org/x/crypto/bcrypt"
)
type AuthInput struct {
Username string `json:"username"`
Password string `json:"password"`
}
func createUser(c *gin.Context) {
var authInput AuthInput
if err := c.ShouldBindJSON(&authInput); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
var user database.User
user, err := database.GetUserByName(Database, authInput.Username)
if err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
if user.ID != 0 {
c.JSON(400, gin.H{"error": "User already exists"})
return
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(authInput.Password), bcrypt.DefaultCost)
if err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
user = database.User{
Name: authInput.Username,
Password: string(passwordHash),
}
err = database.NewUser(Database, user)
if err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
c.JSON(200, gin.H{"message": "User created"})
}
func authenticateUser(c *gin.Context) {
var authInput AuthInput
if err := c.ShouldBindJSON(&authInput); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
var userFound database.User
userFound, err := database.GetUserByName(Database, authInput.Username)
if err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
if userFound.ID == 0 {
c.JSON(400, gin.H{"error": "User not found"})
return
}
err = bcrypt.CompareHashAndPassword([]byte(userFound.Password), []byte(authInput.Password))
if err != nil {
c.JSON(400, gin.H{"error": "Invalid password"})
return
}
generateToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"id": userFound.ID,
"exp": time.Now().Add(time.Hour * 24).Unix(),
})
tokenString, err := generateToken.SignedString([]byte(Env["SECRET"]))
if err != nil {
c.JSON(400, gin.H{"error": "failed to generate token"})
return
}
c.JSON(200, gin.H{"token": tokenString})
}
func getUserProfile(c *gin.Context) {
user, _ := c.Get("currentUser")
c.JSON(200, gin.H{"user": user})
}
func checkAuth(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"})
c.AbortWithStatus(http.StatusUnauthorized)
return
}
authToken := strings.Split(authHeader, " ")
if len(authToken) != 2 || authToken[0] != "Bearer" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format"})
c.AbortWithStatus(http.StatusUnauthorized)
return
}
tokenString := authToken[1]
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(Env["SECRET"]), nil
})
if err != nil || token.Valid {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
c.AbortWithStatus(http.StatusUnauthorized)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"})
c.AbortWithStatus(http.StatusUnauthorized)
return
}
if float64(time.Now().Unix()) > claims["exp"].(float64) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token expired"})
c.AbortWithStatus(http.StatusUnauthorized)
return
}
user, err := database.GetUser(Database, int(claims["id"].(float64)))
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
c.AbortWithStatus(http.StatusUnauthorized)
return
}
if user.ID == 0 {
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
c.AbortWithStatus(http.StatusUnauthorized)
return
}
c.Set("currentUser", user)
c.Next()
}

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"database/sql"
"net/http" "net/http"
"github.com/LeRoid-hub/Bookholder-API/database" "github.com/LeRoid-hub/Bookholder-API/database"
@ -26,30 +27,42 @@ import (
*/ */
var (
Database *sql.DB
Env map[string]string
)
func Run(env map[string]string, db *database.DB) { func Run(env map[string]string, db *database.DB) {
dbase, err := database.New()
if err != nil {
panic(err)
}
Database = dbase
r := gin.Default() r := gin.Default()
v1 := r.Group("/v1") v1 := r.Group("/v1")
{ {
//Account //Account
v1.GET("/Account/:AccountID", getAccount) v1.GET("/Account/:AccountID", checkAuth, getAccount)
v1.POST("/NewAccount", newAccount) v1.POST("/NewAccount", checkAuth, newAccount)
v1.PUT("/UpdateAccount/:AccountID", updateAccount) v1.PUT("/UpdateAccount/:AccountID", checkAuth, updateAccount)
v1.DELETE("/DeleteAccount/:AccountID", deleteAccount) v1.DELETE("/DeleteAccount/:AccountID", checkAuth, deleteAccount)
//Transaction //Transaction
v1.GET("/Transaction/:TransactionID", getTransaction) v1.GET("/Transaction/:TransactionID", checkAuth, getTransaction)
v1.GET("/Transactions/:AccountID/:year", getTransactions) v1.GET("/Transactions/:AccountID/:year", checkAuth, getTransactions)
v1.GET("/Transactions/:AccountID/:year/:month", getTransactions) v1.GET("/Transactions/:AccountID/:year/:month", checkAuth, getTransactions)
v1.POST("/NewTransaction", newTransaction) v1.POST("/NewTransaction", checkAuth, newTransaction)
v1.PUT("/UpdateTransaction/:TransactionID", updateTransaction) v1.PUT("/UpdateTransaction/:TransactionID", checkAuth, updateTransaction)
v1.DELETE("/DeleteTransaction/:TransactionID", deleteTransaction) v1.DELETE("/DeleteTransaction/:TransactionID", checkAuth, deleteTransaction)
//User //User
v1.GET("/User/:UserID", getUser) v1.GET("/User/", checkAuth, getUserProfile)
v1.POST("/NewUser", newUser) v1.POST("/NewUser", createUser, newUser)
v1.PUT("/UpdateUser/:UserID", updateUser) v1.POST("/AuthenticateUser", authenticateUser)
v1.DELETE("/DeleteUser/:UserID", deleteUser) v1.PUT("/UpdateUser/:UserID", checkAuth, updateUser)
v1.DELETE("/DeleteUser/:UserID", checkAuth, deleteUser)
} }
r.GET("/ping", func(c *gin.Context) { r.GET("/ping", func(c *gin.Context) {