119 lines
3.4 KiB
Go
119 lines
3.4 KiB
Go
|
// adapted from https://github.com/crewjam/saml/blob/main/samlidp/service.go
|
||
|
package main
|
||
|
|
||
|
import (
	"bytes"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/crewjam/saml"
	xrv "github.com/mattermost/xml-roundtrip-validator"
)
|
||
|
|
||
|
// copied from https://github.com/crewjam/saml/blob/main/samlidp/util.go
|
||
|
func getSPMetadata(r io.Reader) (spMetadata *saml.EntityDescriptor, err error) {
|
||
|
var data []byte
|
||
|
if data, err = ioutil.ReadAll(r); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
spMetadata = &saml.EntityDescriptor{}
|
||
|
if err := xrv.Validate(bytes.NewBuffer(data)); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if err := xml.Unmarshal(data, &spMetadata); err != nil {
|
||
|
if err.Error() == "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
|
||
|
entities := &saml.EntitiesDescriptor{}
|
||
|
if err := xml.Unmarshal(data, &entities); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
for _, e := range entities.EntityDescriptors {
|
||
|
if len(e.SPSSODescriptors) > 0 {
|
||
|
return &e, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// there were no SPSSODescriptors in the response
|
||
|
return nil, errors.New("metadata contained no service provider metadata")
|
||
|
}
|
||
|
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return spMetadata, nil
|
||
|
}
|
||
|
|
||
|
type BasicServiceProviderProvider struct {
|
||
|
serviceProviders map[string]*saml.EntityDescriptor
|
||
|
}
|
||
|
|
||
|
func modifyACSBindings(entityDescriptor *saml.EntityDescriptor) {
|
||
|
// The SAML library only allows the HTTP-POST Binding (from the IdP
|
||
|
// to the SP), so we need to modify the AssertionConsumerService
|
||
|
// endpoints which use HTTP-Redirect to use HTTP-POST.
|
||
|
for i := 0; i < len(entityDescriptor.SPSSODescriptors); i++ {
|
||
|
spSSODescriptor := &entityDescriptor.SPSSODescriptors[i]
|
||
|
for j := 0; j < len(spSSODescriptor.AssertionConsumerServices); j++ {
|
||
|
acs := &spSSODescriptor.AssertionConsumerServices[j]
|
||
|
if acs.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" {
|
||
|
log.Printf("Replacing Binding for %s from HTTP-Redirect to HTTP-POST", acs.Location)
|
||
|
acs.Binding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func NewServiceProviderProvider(spMetadataPaths []string) saml.ServiceProviderProvider {
|
||
|
spp := &BasicServiceProviderProvider{
|
||
|
serviceProviders: make(map[string]*saml.EntityDescriptor),
|
||
|
}
|
||
|
client := http.Client{}
|
||
|
loadSPMetadata := func(filename string) (*saml.EntityDescriptor, error) {
|
||
|
var r io.ReadCloser
|
||
|
var err error
|
||
|
if strings.HasPrefix(filename, "http://") || strings.HasPrefix(filename, "https://") {
|
||
|
resp, err := client.Get(filename)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
r = resp.Body
|
||
|
} else {
|
||
|
r, err = os.Open(filename)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
defer r.Close()
|
||
|
return getSPMetadata(r)
|
||
|
}
|
||
|
for _, filename := range spMetadataPaths {
|
||
|
metadata, err := loadSPMetadata(filename)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
modifyACSBindings(metadata)
|
||
|
spp.serviceProviders[metadata.EntityID] = metadata
|
||
|
}
|
||
|
return spp
|
||
|
}
|
||
|
|
||
|
// GetServiceProvider returns the Service Provider metadata for the
|
||
|
// service provider ID, which is typically the service provider's
|
||
|
// metadata URL. If an appropriate service provider cannot be found then
|
||
|
// the returned error must be os.ErrNotExist.
|
||
|
func (s *BasicServiceProviderProvider) GetServiceProvider(r *http.Request, serviceProviderID string) (*saml.EntityDescriptor, error) {
|
||
|
rv, ok := s.serviceProviders[serviceProviderID]
|
||
|
if !ok {
|
||
|
return nil, os.ErrNotExist
|
||
|
}
|
||
|
return rv, nil
|
||
|
}
|