diff --git a/README.md b/README.md index 9aa4e119..74c446e1 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ docker-compose up | `BIND_ADDRESS` | The addresses that can access to the web interface and the port, use unix:///abspath/to/file.socket for unix domain socket. | 0.0.0.0:80 | | `SESSION_SECRET` | The secret key used to encrypt the session cookies. Set this to a random value | N/A | | `SESSION_SECRET_FILE` | Optional filepath for the secret key used to encrypt the session cookies. Leave `SESSION_SECRET` blank to take effect | N/A | +| `SESSION_MAX_DURATION` | Max time in days a remembered session is refreshed and valid. Non-refreshed session is valid for 7 days max, regardless of this setting. | 90 | | `SUBNET_RANGES` | The list of address subdivision ranges. Format: `SR Name:10.0.1.0/24; SR2:10.0.2.0/24,10.0.3.0/24` Each CIDR must be inside one of the server interfaces. | N/A | | `WGUI_USERNAME` | The username for the login page. Used for db initialization only | `admin` | | `WGUI_PASSWORD` | The password for the user on the login page. Will be hashed automatically. Used for db initialization only | `admin` | diff --git a/handler/routes.go b/handler/routes.go index 513b0a6a..ef01d086 100644 --- a/handler/routes.go +++ b/handler/routes.go @@ -93,32 +93,41 @@ func Login(db store.IStore) echo.HandlerFunc { } if userCorrect && passwordCorrect { - // TODO: refresh the token ageMax := 0 - expiration := time.Now().Add(24 * time.Hour) if rememberMe { - ageMax = 86400 - expiration.Add(144 * time.Hour) + ageMax = 86400 * 7 } + + cookiePath := util.GetCookiePath() + sess, _ := session.Get("session", c) sess.Options = &sessions.Options{ - Path: util.BasePath, + Path: cookiePath, MaxAge: ageMax, HttpOnly: true, + SameSite: http.SameSiteLaxMode, } // set session_token tokenUID := xid.New().String() + now := time.Now().UTC().Unix() sess.Values["username"] = dbuser.Username + sess.Values["user_hash"] = util.GetDBUserCRC32(dbuser) sess.Values["admin"] = dbuser.Admin sess.Values["session_token"] = tokenUID + sess.Values["max_age"] = ageMax + sess.Values["created_at"] = now + sess.Values["updated_at"] = now sess.Save(c.Request(), c.Response()) // set session_token in cookie cookie := new(http.Cookie) cookie.Name = "session_token" + cookie.Path = cookiePath cookie.Value = tokenUID - cookie.Expires = expiration + cookie.MaxAge = ageMax + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteLaxMode c.SetCookie(cookie) return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Logged in successfully"}) @@ -256,7 +265,7 @@ func UpdateUser(db store.IStore) echo.HandlerFunc { log.Infof("Updated user information successfully") if previousUsername == currentUser(c) { - setUser(c, user.Username, user.Admin) + setUser(c, user.Username, user.Admin, util.GetDBUserCRC32(user)) } return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Updated user information successfully"}) diff --git a/handler/session.go b/handler/session.go index 4cede6e1..b660d9ca 100644 --- a/handler/session.go +++ b/handler/session.go @@ -3,7 +3,9 @@ package handler import ( "fmt" "net/http" + "time" + "github.com/gorilla/sessions" "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" "github.com/ngoduykhanh/wireguard-ui/util" @@ -23,6 +25,15 @@ func ValidSession(next echo.HandlerFunc) echo.HandlerFunc { } } +// RefreshSession must only be used after ValidSession middleware +// RefreshSession checks if the session is eligible for the refresh, but doesn't check if it's fully valid +func RefreshSession(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + doRefreshSession(c) + return next(c) + } +} + func NeedsAdmin(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if !isAdmin(c) { @@ -41,9 +52,146 @@ func isValidSession(c echo.Context) bool { if err != nil || sess.Values["session_token"] != cookie.Value { return false } + + // Check time bounds + createdAt := getCreatedAt(sess) + updatedAt := getUpdatedAt(sess) + maxAge := getMaxAge(sess) + // Temporary session is considered valid within 24h if browser is not closed before + // This value is not saved and is used as virtual expiration + if maxAge == 0 { + maxAge = 86400 + } + expiration := updatedAt + int64(maxAge) + now := time.Now().UTC().Unix() + if updatedAt > now || expiration < now || createdAt+util.SessionMaxDuration < now { + return false + } + + // Check if user still exists and unchanged + username := fmt.Sprintf("%s", sess.Values["username"]) + userHash := getUserHash(sess) + if uHash, ok := util.DBUsersToCRC32[username]; !ok || userHash != uHash { + return false + } + return true } +// Refreshes a "remember me" session when the user visits web pages (not API) +// Session must be valid before calling this function +// Refresh is performed at most once per 24h +func doRefreshSession(c echo.Context) { + if util.DisableLogin { + return + } + + sess, _ := session.Get("session", c) + maxAge := getMaxAge(sess) + if maxAge <= 0 { + return + } + + oldCookie, err := c.Cookie("session_token") + if err != nil || sess.Values["session_token"] != oldCookie.Value { + return + } + + // Refresh no sooner than 24h + createdAt := getCreatedAt(sess) + updatedAt := getUpdatedAt(sess) + expiration := updatedAt + int64(getMaxAge(sess)) + now := time.Now().UTC().Unix() + if updatedAt > now || expiration < now || now-updatedAt < 86_400 || createdAt+util.SessionMaxDuration < now { + return + } + + cookiePath := util.GetCookiePath() + + sess.Values["updated_at"] = now + sess.Options = &sessions.Options{ + Path: cookiePath, + MaxAge: maxAge, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + } + sess.Save(c.Request(), c.Response()) + + cookie := new(http.Cookie) + cookie.Name = "session_token" + cookie.Path = cookiePath + cookie.Value = oldCookie.Value + cookie.MaxAge = maxAge + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteLaxMode + c.SetCookie(cookie) +} + +// Get time in seconds this session is valid without updating +func getMaxAge(sess *sessions.Session) int { + if util.DisableLogin { + return 0 + } + + maxAge := sess.Values["max_age"] + + switch typedMaxAge := maxAge.(type) { + case int: + return typedMaxAge + default: + return 0 + } +} + +// Get a timestamp in seconds of the time the session was created +func getCreatedAt(sess *sessions.Session) int64 { + if util.DisableLogin { + return 0 + } + + createdAt := sess.Values["created_at"] + + switch typedCreatedAt := createdAt.(type) { + case int64: + return typedCreatedAt + default: + return 0 + } +} + +// Get a timestamp in seconds of the last session update +func getUpdatedAt(sess *sessions.Session) int64 { + if util.DisableLogin { + return 0 + } + + lastUpdate := sess.Values["updated_at"] + + switch typedLastUpdate := lastUpdate.(type) { + case int64: + return typedLastUpdate + default: + return 0 + } +} + +// Get CRC32 of a user at the moment of log in +// Any changes to user will result in logout of other (not updated) sessions +func getUserHash(sess *sessions.Session) uint32 { + if util.DisableLogin { + return 0 + } + + userHash := sess.Values["user_hash"] + + switch typedUserHash := userHash.(type) { + case uint32: + return typedUserHash + default: + return 0 + } +} + // currentUser to get username of logged in user func currentUser(c echo.Context) string { if util.DisableLogin { @@ -66,9 +214,10 @@ func isAdmin(c echo.Context) bool { return admin == "true" } -func setUser(c echo.Context, username string, admin bool) { +func setUser(c echo.Context, username string, admin bool, userCRC32 uint32) { sess, _ := session.Get("session", c) sess.Values["username"] = username + sess.Values["user_hash"] = userCRC32 sess.Values["admin"] = admin sess.Save(c.Request(), c.Response()) } @@ -77,7 +226,24 @@ func setUser(c echo.Context, username string, admin bool) { func clearSession(c echo.Context) { sess, _ := session.Get("session", c) sess.Values["username"] = "" + sess.Values["user_hash"] = 0 sess.Values["admin"] = false sess.Values["session_token"] = "" + sess.Values["max_age"] = -1 + sess.Options.MaxAge = -1 sess.Save(c.Request(), c.Response()) + + cookiePath := util.GetCookiePath() + + cookie, err := c.Cookie("session_token") + if err != nil { + cookie = new(http.Cookie) + } + + cookie.Name = "session_token" + cookie.Path = cookiePath + cookie.MaxAge = -1 + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteLaxMode + c.SetCookie(cookie) } diff --git a/main.go b/main.go index 3d67e70e..1125746f 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "crypto/sha512" "embed" "flag" "fmt" @@ -48,6 +49,7 @@ var ( flagTelegramAllowConfRequest = false flagTelegramFloodWait = 60 flagSessionSecret = util.RandomString(32) + flagSessionMaxDuration = 90 flagWgConfTemplate string flagBasePath string flagSubnetRanges string @@ -91,6 +93,7 @@ func init() { flag.StringVar(&flagWgConfTemplate, "wg-conf-template", util.LookupEnvOrString("WG_CONF_TEMPLATE", flagWgConfTemplate), "Path to custom wg.conf template.") flag.StringVar(&flagBasePath, "base-path", util.LookupEnvOrString("BASE_PATH", flagBasePath), "The base path of the URL") flag.StringVar(&flagSubnetRanges, "subnet-ranges", util.LookupEnvOrString("SUBNET_RANGES", flagSubnetRanges), "IP ranges to choose from when assigning an IP for a client.") + flag.IntVar(&flagSessionMaxDuration, "session-max-duration", util.LookupEnvOrInt("SESSION_MAX_DURATION", flagSessionMaxDuration), "Max time in days a remembered session is refreshed and valid.") var ( smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword) @@ -135,7 +138,8 @@ func init() { util.SendgridApiKey = flagSendgridApiKey util.EmailFrom = flagEmailFrom util.EmailFromName = flagEmailFromName - util.SessionSecret = []byte(flagSessionSecret) + util.SessionSecret = sha512.Sum512([]byte(flagSessionSecret)) + util.SessionMaxDuration = int64(flagSessionMaxDuration) * 86_400 // Store in seconds util.WgConfTemplate = flagWgConfTemplate util.BasePath = util.ParseBasePath(flagBasePath) util.SubnetRanges = util.ParseSubnetRanges(flagSubnetRanges) @@ -204,7 +208,7 @@ func main() { // register routes app := router.New(tmplDir, extraData, util.SessionSecret) - app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession) + app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession, handler.RefreshSession) // Important: Make sure that all non-GET routes check the request content type using handler.ContentTypeJson to // mitigate CSRF attacks. This is effective, because browsers don't allow setting the Content-Type header on @@ -214,8 +218,8 @@ func main() { app.GET(util.BasePath+"/login", handler.LoginPage()) app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson) app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession) - app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession) - app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.NeedsAdmin) + app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession, handler.RefreshSession) + app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin) app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson) app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) @@ -241,19 +245,19 @@ func main() { app.POST(util.BasePath+"/client/set-status", handler.SetClientStatus(db), handler.ValidSession, handler.ContentTypeJson) app.POST(util.BasePath+"/remove-client", handler.RemoveClient(db), handler.ValidSession, handler.ContentTypeJson) app.GET(util.BasePath+"/download", handler.DownloadClient(db), handler.ValidSession) - app.GET(util.BasePath+"/wg-server", handler.WireGuardServer(db), handler.ValidSession, handler.NeedsAdmin) + app.GET(util.BasePath+"/wg-server", handler.WireGuardServer(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin) app.POST(util.BasePath+"/wg-server/interfaces", handler.WireGuardServerInterfaces(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) app.POST(util.BasePath+"/wg-server/keypair", handler.WireGuardServerKeyPair(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) - app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db), handler.ValidSession, handler.NeedsAdmin) + app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin) app.POST(util.BasePath+"/global-settings", handler.GlobalSettingSubmit(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) - app.GET(util.BasePath+"/status", handler.Status(db), handler.ValidSession) + app.GET(util.BasePath+"/status", handler.Status(db), handler.ValidSession, handler.RefreshSession) app.GET(util.BasePath+"/api/clients", handler.GetClients(db), handler.ValidSession) app.GET(util.BasePath+"/api/client/:id", handler.GetClient(db), handler.ValidSession) app.GET(util.BasePath+"/api/machine-ips", handler.MachineIPAddresses(), handler.ValidSession) app.GET(util.BasePath+"/api/subnet-ranges", handler.GetOrderedSubnetRanges(), handler.ValidSession) app.GET(util.BasePath+"/api/suggest-client-ips", handler.SuggestIPAllocation(db), handler.ValidSession) app.POST(util.BasePath+"/api/apply-wg-config", handler.ApplyServerConfig(db, tmplDir), handler.ValidSession, handler.ContentTypeJson) - app.GET(util.BasePath+"/wake_on_lan_hosts", handler.GetWakeOnLanHosts(db), handler.ValidSession) + app.GET(util.BasePath+"/wake_on_lan_hosts", handler.GetWakeOnLanHosts(db), handler.ValidSession, handler.RefreshSession) app.POST(util.BasePath+"/wake_on_lan_host", handler.SaveWakeOnLanHost(db), handler.ValidSession, handler.ContentTypeJson) app.DELETE(util.BasePath+"/wake_on_lan_host/:mac_address", handler.DeleteWakeOnHost(db), handler.ValidSession, handler.ContentTypeJson) app.PUT(util.BasePath+"/wake_on_lan_host/:mac_address", handler.WakeOnHost(db), handler.ValidSession, handler.ContentTypeJson) diff --git a/router/router.go b/router/router.go index 569ebafa..59d352eb 100644 --- a/router/router.go +++ b/router/router.go @@ -48,9 +48,17 @@ func (t *TemplateRegistry) Render(w io.Writer, name string, data interface{}, c } // New function -func New(tmplDir fs.FS, extraData map[string]interface{}, secret []byte) *echo.Echo { +func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo.Echo { e := echo.New() - e.Use(session.Middleware(sessions.NewCookieStore(secret))) + + cookiePath := util.GetCookiePath() + + cookieStore := sessions.NewCookieStore(secret[:32], secret[32:]) + cookieStore.Options.Path = cookiePath + cookieStore.Options.HttpOnly = true + cookieStore.MaxAge(86400 * 7) + + e.Use(session.Middleware(cookieStore)) // read html template file to string tmplBaseString, err := util.StringFromEmbedFile(tmplDir, "base.html") diff --git a/store/jsondb/jsondb.go b/store/jsondb/jsondb.go index 5d010661..1401b2cd 100644 --- a/store/jsondb/jsondb.go +++ b/store/jsondb/jsondb.go @@ -161,6 +161,14 @@ func (o *JsonDB) Init() error { } // init cache + for _, i := range results { + user := model.User{} + + if err := json.Unmarshal([]byte(i), &user); err == nil { + util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user) + } + } + clients, err := o.GetClients(false) if err != nil { return nil @@ -214,11 +222,13 @@ func (o *JsonDB) SaveUser(user model.User) error { if err != nil { return err } + util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user) return output } // DeleteUser func to remove user from the database func (o *JsonDB) DeleteUser(username string) error { + delete(util.DBUsersToCRC32, username) return o.conn.Delete("users", username) } diff --git a/util/cache.go b/util/cache.go index b9694b92..48b37eac 100644 --- a/util/cache.go +++ b/util/cache.go @@ -5,3 +5,4 @@ import "sync" var IPToSubnetRange = map[string]uint16{} var TgUseridToClientID = map[int64][]string{} var TgUseridToClientIDMutex sync.RWMutex +var DBUsersToCRC32 = map[string]uint32{} diff --git a/util/config.go b/util/config.go index 796775c1..4af6bd2b 100644 --- a/util/config.go +++ b/util/config.go @@ -9,24 +9,25 @@ import ( // Runtime config var ( - DisableLogin bool - BindAddress string - SmtpHostname string - SmtpPort int - SmtpUsername string - SmtpPassword string - SmtpNoTLSCheck bool - SmtpEncryption string - SmtpAuthType string - SmtpHelo string - SendgridApiKey string - EmailFrom string - EmailFromName string - SessionSecret []byte - WgConfTemplate string - BasePath string - SubnetRanges map[string]([]*net.IPNet) - SubnetRangesOrder []string + DisableLogin bool + BindAddress string + SmtpHostname string + SmtpPort int + SmtpUsername string + SmtpPassword string + SmtpNoTLSCheck bool + SmtpEncryption string + SmtpAuthType string + SmtpHelo string + SendgridApiKey string + EmailFrom string + EmailFromName string + SessionSecret [64]byte + SessionMaxDuration int64 + WgConfTemplate string + BasePath string + SubnetRanges map[string]([]*net.IPNet) + SubnetRangesOrder []string ) const ( diff --git a/util/util.go b/util/util.go index 88b70899..06b87c3d 100644 --- a/util/util.go +++ b/util/util.go @@ -2,9 +2,12 @@ package util import ( "bufio" + "bytes" + "encoding/gob" "encoding/json" "errors" "fmt" + "hash/crc32" "io" "io/fs" "math/rand" @@ -827,3 +830,38 @@ func filterStringSlice(s []string, excludedStr string) []string { } return filtered } + +func GetDBUserCRC32(dbuser model.User) uint32 { + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + if err := enc.Encode(dbuser); err != nil { + panic("model.User is gob-incompatible, session verification is impossible") + } + return crc32.ChecksumIEEE(buf.Bytes()) +} + +func ConcatMultipleSlices(slices ...[]byte) []byte { + var totalLen int + + for _, s := range slices { + totalLen += len(s) + } + + result := make([]byte, totalLen) + + var i int + + for _, s := range slices { + i += copy(result[i:], s) + } + + return result +} + +func GetCookiePath() string { + cookiePath := BasePath + if cookiePath == "" { + cookiePath = "/" + } + return cookiePath +}