// File: gin-oauth2-demo/middleware/middleware.go
// (432 lines, 12 KiB, Go; last change 2025-02-12 16:01:36 +01:00)

package middleware
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"git.0x0001f346.de/andreas/gin-oauth2-demo/core"
"git.0x0001f346.de/andreas/gin-oauth2-demo/repository"
"github.com/gin-gonic/gin"
"golang.org/x/oauth2"
)
// introspectionResponse is the JSON body returned by Keycloak's token
// introspection endpoint (decoded by introspect). In this file only Active
// and Groups are read; the remaining claims are decoded but unused.
type introspectionResponse struct {
	Exp      int    `json:"exp,omitempty"`
	Iat      int    `json:"iat,omitempty"`
	AuthTime int    `json:"auth_time,omitempty"`
	Jti      string `json:"jti,omitempty"`
	Iss      string `json:"iss,omitempty"`
	// Aud uses the audience type because RFC 7662 allows "aud" to be either a
	// single string or an array of strings. With a plain string field, any
	// multi-audience token fails to decode, introspect returns an error and
	// the user is sent back to login on every protected URL.
	Aud            audience `json:"aud,omitempty"`
	Sub            string   `json:"sub,omitempty"`
	Typ            string   `json:"typ,omitempty"`
	Azp            string   `json:"azp,omitempty"`
	Sid            string   `json:"sid,omitempty"`
	Acr            string   `json:"acr,omitempty"`
	AllowedOrigins []string `json:"allowed-origins,omitempty"`
	RealmAccess    struct {
		Roles []string `json:"roles,omitempty"`
	} `json:"realm_access,omitempty"`
	ResourceAccess struct {
		Account struct {
			Roles []string `json:"roles,omitempty"`
		} `json:"account,omitempty"`
	} `json:"resource_access,omitempty"`
	Scope             string   `json:"scope,omitempty"`
	EmailVerified     bool     `json:"email_verified,omitempty"`
	Name              string   `json:"name,omitempty"`
	Groups            []string `json:"groups,omitempty"`
	PreferredUsername string   `json:"preferred_username,omitempty"`
	GivenName         string   `json:"given_name,omitempty"`
	FamilyName        string   `json:"family_name,omitempty"`
	Email             string   `json:"email,omitempty"`
	ClientID          string   `json:"client_id,omitempty"`
	Username          string   `json:"username,omitempty"`
	TokenType         string   `json:"token_type,omitempty"`
	// Active is false when the token is no longer valid (e.g. revoked).
	Active bool `json:"active"`
}

// audience is a JSON "string or array of strings" value, normalized to a
// slice. An empty string or JSON null decodes to a nil slice.
type audience []string

// UnmarshalJSON accepts a JSON string, an array of strings, or null.
func (a *audience) UnmarshalJSON(data []byte) error {
	var single string
	if err := json.Unmarshal(data, &single); err == nil {
		if single == "" {
			*a = nil
		} else {
			*a = audience{single}
		}
		return nil
	}
	var many []string
	if err := json.Unmarshal(data, &many); err != nil {
		return fmt.Errorf("decoding aud claim: expected string or array of strings: %w", err)
	}
	*a = many
	return nil
}
// URLPrefix is the path prefix of the authentication routes; GetURLCallback,
// GetURLLogin and GetURLLogout prepend it to the route-local paths below.
const URLPrefix string = "/auth"

const (
	// accessGroupNeededForThisApp is the group a user must be a member of;
	// it is checked at login and on every protected URL.
	accessGroupNeededForThisApp string = "db-users"
	// minTokenValiditySeconds is the remaining token lifetime (12 hours) at or
	// below which refreshTokenIfNecessary refreshes the token.
	minTokenValiditySeconds float64 = 60 * 60 * 12
	// nameHTTPCookie is the name of the cookie that carries the access token.
	nameHTTPCookie string = "token"

	// Route-local paths; combined with URLPrefix by the GetURL* helpers.
	urlAppCallback string = "/callback"
	urlAppLogin    string = "/login"
	urlAppLogout   string = "/logout"

	// OpenID Connect endpoints of the Keycloak realm.
	urlKeycloakAuth        string = "https://auth.mydomain.tld/realms/myrealm/protocol/openid-connect/auth"
	urlKeycloakIOntrospect string = "https://auth.mydomain.tld/realms/myrealm/protocol/openid-connect/token/introspect"
	urlKeycloakLogout      string = "https://auth.mydomain.tld/realms/myrealm/protocol/openid-connect/logout"
	urlKeycloakToken       string = "https://auth.mydomain.tld/realms/myrealm/protocol/openid-connect/token"
	urlKeycloakUserinfo    string = "https://auth.mydomain.tld/realms/myrealm/protocol/openid-connect/userinfo"
)

// oAuthConfig is the OAuth2 client configuration used for the login redirect,
// the authorization-code exchange and token refreshes; its client credentials
// are also sent to the introspection and logout endpoints.
//
// NOTE(review): the client secret is committed in source. It should be loaded
// from the environment or a secret store (and the current value rotated).
var oAuthConfig = oauth2.Config{
	ClientID:     "db.mydomain.tld",
	ClientSecret: "a3wFLNGDiNHyNLUM7peYLL97JBkE3ltk",
	RedirectURL:  "https://db.mydomain.tld" + GetURLCallback(),
	Endpoint: oauth2.Endpoint{
		AuthURL:  urlKeycloakAuth,
		TokenURL: urlKeycloakToken,
	},
	Scopes: []string{"groups", "openid", "profile", "email"},
}

// protectedURLs is the set of route paths registered via ProtectURL.
var protectedURLs = map[string]bool{}
// Auth returns middleware that requires a valid session for every route except
// the login and callback routes. On success it leaves "accessToken" and "user"
// in the gin context and applies the extra checks for protected URLs; on any
// failure it clears the session cookie and redirects to the login route.
func Auth() gin.HandlerFunc {
	return func(c *gin.Context) {
		if isRouteWithoutAuth(c) {
			return
		}
		if err := authenticateRequest(c); err != nil {
			deleteCookieAndRedirectToLogin(c)
			return
		}
		protectURLIfNecessary(c)
	}
}

// authenticateRequest resolves the session for c, storing "accessToken" and
// "user" in the context, and refreshes the token when needed. A non-nil error
// means the session is unusable.
func authenticateRequest(c *gin.Context) error {
	token, err := getAccessTokenFromRequest(c)
	if err != nil {
		return err
	}
	c.Set("accessToken", token)

	user, err := getUserForAccessToken(token)
	if err != nil {
		return err
	}
	c.Set("user", user)

	return refreshTokenIfNecessary(c)
}
func GetURLCallback() string {
return fmt.Sprintf("%s%s", URLPrefix, urlAppCallback)
}
func GetURLLogin() string {
return fmt.Sprintf("%s%s", URLPrefix, urlAppLogin)
}
func GetURLLogout() string {
return fmt.Sprintf("%s%s", URLPrefix, urlAppLogout)
}
// ProtectURL registers a route path as protected. For protected paths, Auth
// introspects the access token at Keycloak on every request and re-checks the
// user's group membership. The path is compared against gin's
// Context.FullPath, so pass the route pattern as registered with gin.
func ProtectURL(url string) {
	// Only key presence matters; see isProtectedURL.
	protectedURLs[url] = true
}
// SetupRoutes registers the callback, login and logout GET handlers on rg,
// using the route-local paths (rg is expected to carry URLPrefix).
func SetupRoutes(rg *gin.RouterGroup) {
	routes := []struct {
		path    string
		handler gin.HandlerFunc
	}{
		{urlAppCallback, routeCallback},
		{urlAppLogin, routeLogin},
		{urlAppLogout, routeLogout},
	}
	for _, r := range routes {
		rg.GET(r.path, r.handler)
	}
}
// deleteCookieAndRedirectToLogin expires the session cookie, answers with a
// 303 See Other redirect to the login route and aborts the handler chain.
func deleteCookieAndRedirectToLogin(c *gin.Context) {
	// A negative max-age tells the browser to drop the cookie immediately.
	const expireNow = -1
	c.SetCookie(nameHTTPCookie, "", expireNow, "/", "", false, true)

	loginURL := GetURLLogin()
	c.Redirect(http.StatusSeeOther, loginURL)
	c.Abort()
}
// fetchUserFromKeycloak calls the Keycloak userinfo endpoint with accessToken
// as bearer token and decodes the JSON response into a core.User.
//
// It returns an error if the request cannot be built or sent, if the endpoint
// answers with anything other than 200 OK, or if the body cannot be decoded.
func fetchUserFromKeycloak(accessToken string) (core.User, error) {
	req, err := http.NewRequest(http.MethodGet, urlKeycloakUserinfo, nil)
	if err != nil {
		return core.User{}, err
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)

	// The zero http.Client never times out; without a limit a slow or
	// unreachable Keycloak would block the callback handler indefinitely.
	// 5s matches the other Keycloak calls in this package.
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return core.User{}, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return core.User{}, fmt.Errorf("failed to get user info, status: %d", resp.StatusCode)
	}

	var userInfo core.User
	if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
		return core.User{}, err
	}
	return userInfo, nil
}
// getAccessTokenFromRequest reads the session cookie and looks its value up in
// the repository to obtain the stored OAuth2 token. A missing cookie yields an
// error naming the cookie; repository errors are returned unchanged.
func getAccessTokenFromRequest(c *gin.Context) (oauth2.Token, error) {
	cookieValue, cookieErr := c.Cookie(nameHTTPCookie)
	if cookieErr != nil {
		return oauth2.Token{}, fmt.Errorf("cookie '%s' not found", nameHTTPCookie)
	}

	stored, lookupErr := repository.GetAccessToken(cookieValue)
	if lookupErr != nil {
		return oauth2.Token{}, lookupErr
	}
	return stored, nil
}
// getUserForAccessToken resolves the user UUID mapped to accessToken and loads
// that user. If either lookup fails, the repository entries for the token that
// can no longer be used are deleted before the error is returned.
func getUserForAccessToken(accessToken oauth2.Token) (core.User, error) {
	key := accessToken.AccessToken

	userID, err := repository.GetAccessTokenToUserMapping(key)
	if err != nil {
		repository.DeleteAccessToken(key)
		return core.User{}, err
	}

	user, err := repository.GetUser(userID)
	if err == nil {
		return user, nil
	}
	repository.DeleteAccessTokenToUserMapping(key)
	repository.DeleteAccessToken(key)
	return core.User{}, err
}
// introspect posts token to Keycloak's introspection endpoint, authenticating
// with this client's ID and secret via HTTP Basic auth, and decodes the reply.
// A transport error, a non-200 status or an undecodable body is returned as
// an error.
func introspect(token string) (introspectionResponse, error) {
	form := url.Values{"token": {token}}
	req, err := http.NewRequest(http.MethodPost, urlKeycloakIOntrospect, strings.NewReader(form.Encode()))
	if err != nil {
		return introspectionResponse{}, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth(oAuthConfig.ClientID, oAuthConfig.ClientSecret)

	httpClient := &http.Client{Timeout: 5 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return introspectionResponse{}, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return introspectionResponse{}, fmt.Errorf("introspection failed with status: %s", resp.Status)
	}

	var result introspectionResponse
	if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return introspectionResponse{}, err
	}
	return result, nil
}
// isProtectedURL reports whether url was registered via ProtectURL. Only the
// presence of the key is checked, not its stored value.
func isProtectedURL(url string) bool {
	if _, found := protectedURLs[url]; found {
		return true
	}
	return false
}
// isRouteWithoutAuth reports whether the matched route is one that must be
// reachable without a session: the OAuth2 callback or the login route.
func isRouteWithoutAuth(c *gin.Context) bool {
	switch c.FullPath() {
	case GetURLCallback(), GetURLLogin():
		return true
	default:
		return false
	}
}
// logoutFromKeycloak posts refreshToken, together with this client's ID and
// secret, to the Keycloak logout endpoint. Any status other than 204 No
// Content is reported as an error that includes the response body.
func logoutFromKeycloak(refreshToken string) error {
	form := url.Values{
		"client_id":     {oAuthConfig.ClientID},
		"client_secret": {oAuthConfig.ClientSecret},
		"refresh_token": {refreshToken},
	}
	req, err := http.NewRequest(http.MethodPost, urlKeycloakLogout, strings.NewReader(form.Encode()))
	if err != nil {
		return fmt.Errorf("error when creating the logout request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	httpClient := &http.Client{Timeout: 5 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("error when sending the logout request: %w", err)
	}
	defer resp.Body.Close()

	// The body is always read (and only used for the error message); a read
	// error just leaves the message body shorter.
	respBody, _ := io.ReadAll(resp.Body)
	if resp.StatusCode == http.StatusNoContent {
		return nil
	}
	return fmt.Errorf("logout failed, Status: %d, Body: %s", resp.StatusCode, respBody)
}
// protectURLIfNecessary applies the extra checks for routes registered via
// ProtectURL; other routes pass through untouched. For a protected route the
// current access token is introspected at Keycloak:
//
//   - introspection fails, or the token is no longer active: the token is
//     removed from the repository and the user is redirected to login;
//   - the token lacks the required group: the Keycloak session is logged
//     out (best-effort), token and user are removed from the repository and
//     the user is redirected to login.
//
// It expects Auth to have stored "accessToken" and "user" in the context.
func protectURLIfNecessary(c *gin.Context) {
	if !isProtectedURL(c.FullPath()) {
		return
	}
	accessToken := c.MustGet("accessToken").(oauth2.Token)
	user := c.MustGet("user").(core.User)

	resp, err := introspect(accessToken.AccessToken)
	if err != nil || !resp.Active {
		// Introspection failed or the session was revoked. Return here:
		// falling through would also treat the user as having lost the
		// group, delete their record and redirect a second time.
		repository.DeleteAccessToken(accessToken.AccessToken)
		repository.DeleteAccessTokenToUserMapping(accessToken.AccessToken)
		deleteCookieAndRedirectToLogin(c)
		return
	}

	if resp.DoesUserHasNecessaryGroupMembership() {
		return
	}

	// The user has lost the required group membership.
	// Keycloak logout is best-effort; local state is cleared regardless.
	_ = logoutFromKeycloak(accessToken.RefreshToken)
	repository.DeleteAccessToken(accessToken.AccessToken)
	repository.DeleteAccessTokenToUserMapping(accessToken.AccessToken)
	repository.DeleteUser(user.UUID)
	deleteCookieAndRedirectToLogin(c)
}
// refreshTokenIfNecessary refreshes the session's OAuth2 token once its
// remaining lifetime is at or below minTokenValiditySeconds.
//
// On success the new token replaces the old one in the repository, in the gin
// context ("accessToken") and in the session cookie. If the refresh itself is
// rejected, the Keycloak session is logged out (best-effort), the old token is
// removed from the repository and the error is returned. If storing the new
// token fails, the refresh is abandoned and the request continues with the old
// token.
func refreshTokenIfNecessary(c *gin.Context) error {
	accessToken := c.MustGet("accessToken").(oauth2.Token)
	if time.Until(accessToken.Expiry).Seconds() > minTokenValiditySeconds {
		return nil
	}
	user := c.MustGet("user").(core.User)

	// Keep context.Background() rather than the request context: a client
	// disconnecting mid-request must not cancel the refresh, because a failed
	// refresh tears down the whole session below.
	tokenSource := oAuthConfig.TokenSource(
		context.Background(),
		&oauth2.Token{RefreshToken: accessToken.RefreshToken},
	)
	newToken, err := tokenSource.Token()
	if err != nil {
		_ = logoutFromKeycloak(accessToken.RefreshToken)
		repository.DeleteAccessToken(accessToken.AccessToken)
		repository.DeleteAccessTokenToUserMapping(accessToken.AccessToken)
		return err
	}

	// Best-effort persistence: the old token is still stored, so on failure
	// this request keeps using it and the refresh is retried next request.
	if err := repository.SetAccessToken(*newToken); err != nil {
		return nil
	}
	if err := repository.SetAccessTokenToUserMapping(newToken.AccessToken, user.UUID); err != nil {
		// Do not leave a stored token without a user mapping behind.
		repository.DeleteAccessToken(newToken.AccessToken)
		return nil
	}
	repository.DeleteAccessToken(accessToken.AccessToken)
	repository.DeleteAccessTokenToUserMapping(accessToken.AccessToken)

	c.Set("accessToken", *newToken)
	// The cookie must carry the NEW token and its expiry: the old token was
	// just deleted from the repository and could not be resolved anymore.
	expiresIn := int(time.Until(newToken.Expiry).Seconds())
	c.SetCookie(nameHTTPCookie, newToken.AccessToken, expiresIn, "/", "", false, true)
	return nil
}
// routeCallback handles the OAuth2 redirect back from Keycloak. It exchanges
// the authorization code for a token, fetches the user's profile and checks
// membership in accessGroupNeededForThisApp. Users without the group get 403
// and their stored state is removed; otherwise user and token are persisted,
// the session cookie is set and the browser is redirected to "/".
//
// NOTE(review): the "state" parameter is not validated (routeLogin sends a
// fixed value), so this flow is not protected against login CSRF.
func routeCallback(c *gin.Context) {
	code := c.Query("code")
	if code == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Authorization code missing"})
		return
	}

	// Nothing is persisted before the exchange, so tying it to the request's
	// lifetime is safe.
	token, err := oAuthConfig.Exchange(c.Request.Context(), code)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange token"})
		return
	}

	fetchedUser, err := fetchUserFromKeycloak(token.AccessToken)
	if err != nil {
		// Use the message: most error values marshal to an empty JSON object.
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch user info", "err": err.Error()})
		return
	}

	if !fetchedUser.HasGroupMembership(accessGroupNeededForThisApp) {
		// Clear any state a previous login of this user may have left behind.
		repository.DeleteUser(fetchedUser.UUID)
		repository.DeleteAccessTokenToUserMapping(token.AccessToken)
		repository.DeleteAccessToken(token.AccessToken)
		c.SetCookie(nameHTTPCookie, "", -1, "/", "", false, true)
		c.JSON(http.StatusForbidden, gin.H{"error": "403 Forbidden"})
		return
	}

	repository.SetUser(fetchedUser)
	// If the token is not stored, Auth cannot resolve the cookie set below and
	// the browser would bounce between login and callback; fail loudly instead.
	if err := repository.SetAccessToken(*token); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to store session"})
		return
	}
	if err := repository.SetAccessTokenToUserMapping(token.AccessToken, fetchedUser.UUID); err != nil {
		repository.DeleteAccessToken(token.AccessToken)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to store session"})
		return
	}

	expiresIn := int(time.Until(token.Expiry).Seconds())
	c.SetCookie(nameHTTPCookie, token.AccessToken, expiresIn, "/", "", false, true)
	c.Redirect(http.StatusSeeOther, "/")
}
// routeLogin redirects the browser (303 See Other) to Keycloak's authorization
// endpoint, requesting offline access so that a refresh token is issued.
//
// NOTE(review): the state value is a fixed string, so it offers no CSRF
// protection; it should be random per login and verified in routeCallback.
func routeLogin(c *gin.Context) {
	authURL := oAuthConfig.AuthCodeURL("random-state-string", oauth2.AccessTypeOffline)
	c.Redirect(http.StatusSeeOther, authURL)
}
// routeLogout ends the current session: it asks Keycloak to log the session
// out (best-effort), removes the token from the repository, expires the
// session cookie and redirects to the login route.
// It expects Auth to have stored "accessToken" in the context.
func routeLogout(c *gin.Context) {
	accessToken := c.MustGet("accessToken").(oauth2.Token)

	// Best-effort: local state is cleared even if Keycloak cannot be reached.
	_ = logoutFromKeycloak(accessToken.RefreshToken)

	// Clearing only the cookie would leave the token resolvable server-side.
	repository.DeleteAccessToken(accessToken.AccessToken)
	repository.DeleteAccessTokenToUserMapping(accessToken.AccessToken)

	c.SetCookie(nameHTTPCookie, "", -1, "/", "", false, true)
	// The login route lives below URLPrefix (isRouteWithoutAuth matches it via
	// GetURLLogin), so the bare route-local path would miss it.
	c.Redirect(http.StatusSeeOther, GetURLLogin())
}
// DoesUserHasNecessaryGroupMembership reports whether the introspected token's
// "groups" claim contains accessGroupNeededForThisApp (exact string match).
func (r introspectionResponse) DoesUserHasNecessaryGroupMembership() bool {
	found := false
	for i := 0; i < len(r.Groups) && !found; i++ {
		found = r.Groups[i] == accessGroupNeededForThisApp
	}
	return found
}