pyceo/web/internal/api/middleware.go

94 lines
2.4 KiB
Go

package api
import (
"errors"
"net/http"
"strconv"
"strings"
"github.com/labstack/echo/v4"
"git.csclub.uwaterloo.ca/public/pyceo/web/internal/app"
"git.csclub.uwaterloo.ca/public/pyceo/web/internal/config"
"git.csclub.uwaterloo.ca/public/pyceo/web/pkg/logging"
)
// helmet returns middleware that adds a fixed set of security-related
// response headers to every request: Content-Security-Policy,
// Cross-Origin-Opener-Policy, Cross-Origin-Resource-Policy, Referrer-Policy
// and, when cfg.HstsMaxAge is non-zero, Strict-Transport-Security.
//
// The CSP string is built once, when the middleware is created, from
// cfg.IsDev. In development, http: sources are additionally allowed for fonts
// and styles. Outside development, upgrade-insecure-requests is appended.
func helmet(cfg *config.Config) echo.MiddlewareFunc {
	var schemes string
	if cfg.IsDev {
		// Development additionally permits plain-HTTP font/style sources.
		schemes = "http: https:"
	} else {
		schemes = "https:"
	}

	directives := []string{
		"default-src 'self'",
		"base-uri 'self'",
		"font-src 'self' " + schemes + " data:",
		"form-action 'self'",
		"frame-ancestors 'self'",
		"img-src 'self' data:",
		"object-src 'none'",
		"script-src 'self'",
		"script-src-attr 'none'",
		"style-src 'self' " + schemes + " 'unsafe-inline'",
	}
	if !cfg.IsDev {
		directives = append(directives, "upgrade-insecure-requests")
	}

	// Header name/value pairs that are identical for every request, listed in
	// the order they are applied.
	fixedHeaders := [][2]string{
		{echo.HeaderContentSecurityPolicy, strings.Join(directives, ";")},
		{"Cross-Origin-Opener-Policy", "same-origin"},
		{"Cross-Origin-Resource-Policy", "same-origin"},
		{echo.HeaderReferrerPolicy, "no-referrer"},
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			header := c.Response().Header()
			for _, kv := range fixedHeaders {
				header.Set(kv[0], kv[1])
			}
			// HstsMaxAge is read on every request; a zero value means no HSTS header.
			if maxAge := cfg.HstsMaxAge; maxAge != 0 {
				header.Set(
					echo.HeaderStrictTransportSecurity,
					"max-age="+strconv.FormatInt(int64(maxAge), 10),
				)
			}
			return next(c)
		}
	}
}
// getReqInfoFromHTTPHeaders builds the per-request user information from the
// X-Csc-Adfs-Username and X-Csc-Adfs-Firstname request headers.
//
// It returns an error in these cases:
//   - the username header is absent or empty;
//   - the username header is present more than once. Which value is
//     authoritative would then be ambiguous, and accepting the first one could
//     let a client-supplied header take precedence;
//   - the given-name header is absent.
//
// On error the returned *app.ReqInfo is nil.
func getReqInfoFromHTTPHeaders(r *http.Request) (*app.ReqInfo, error) {
	// Header names must be in canonical form (see http.CanonicalHeaderKey),
	// because the header map is indexed directly rather than via Header.Get.
	usernames := r.Header["X-Csc-Adfs-Username"]
	// An empty value carries no identity, so treat it the same as a missing
	// header instead of producing a ReqInfo with an empty Username.
	if len(usernames) == 0 || usernames[0] == "" {
		return nil, errors.New("Username is missing from HTTP headers")
	}
	if len(usernames) > 1 {
		return nil, errors.New("Username appears more than once in HTTP headers")
	}
	givenNames := r.Header["X-Csc-Adfs-Firstname"]
	if len(givenNames) == 0 {
		return nil, errors.New("Given name is missing from HTTP headers")
	}
	return &app.ReqInfo{Username: usernames[0], GivenName: givenNames[0]}, nil
}
// appContext wraps the echo.Context of a request and carries two extra values:
// the user information parsed from the request headers, and the *app.App
// handed to appContextMiddleware, which is where instances are constructed in
// this file.
type appContext struct {
echo.Context
// req is the user information returned by getReqInfoFromHTTPHeaders.
req *app.ReqInfo
// app is the application instance passed to appContextMiddleware.
app *app.App
}
// Log returns the logger of the embedded echo.Context, typed as a
// logging.Logger.
func (ac *appContext) Log() logging.Logger {
return ac.Context.Logger()
}
// Req returns the user information attached to this request context.
func (ac *appContext) Req() *app.ReqInfo {
return ac.req
}
// appContextMiddleware returns middleware that replaces the echo.Context given
// to downstream handlers with an *appContext. That context carries the
// request's user information and the supplied application instance.
//
// If the user information cannot be read from the request headers, the error
// from getReqInfoFromHTTPHeaders is returned unchanged and the next handler is
// not called.
func appContextMiddleware(application *app.App) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			info, err := getReqInfoFromHTTPHeaders(c.Request())
			if err != nil {
				return err
			}
			wrapped := &appContext{
				Context: c,
				req:     info,
				app:     application,
			}
			return next(wrapped)
		}
	}
}