You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
saml-passthrough/service_provider_provider.go

118 lines
3.4 KiB

// adapted from https://github.com/crewjam/saml/blob/main/samlidp/service.go
package main
import (
	"bytes"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/crewjam/saml"
	xrv "github.com/mattermost/xml-roundtrip-validator"
)
// getSPMetadata reads SAML metadata from r and returns the service provider
// entity it describes.
//
// The document may be rooted at either an <EntityDescriptor>, which is
// returned as-is, or an <EntitiesDescriptor> aggregate, in which case the
// first contained entity with at least one SPSSODescriptor is returned; an
// aggregate without any such entity yields an error. Before unmarshalling,
// the raw bytes are checked with the xml-roundtrip-validator.
//
// Adapted from https://github.com/crewjam/saml/blob/main/samlidp/util.go.
func getSPMetadata(r io.Reader) (spMetadata *saml.EntityDescriptor, err error) {
	var data []byte
	if data, err = ioutil.ReadAll(r); err != nil {
		return nil, err
	}
	if err := xrv.Validate(bytes.NewBuffer(data)); err != nil {
		return nil, err
	}

	// Choose the target type from the root element's local name rather than
	// by matching the text of encoding/xml's "expected element type" error,
	// which is not part of that package's API and may change between Go
	// releases. Tokenizing stops at the first start element, so this is cheap.
	// If the document cannot be tokenized, fall through and let xml.Unmarshal
	// report the problem.
	isAggregate := false
	dec := xml.NewDecoder(bytes.NewReader(data))
	for {
		tok, tokErr := dec.Token()
		if tokErr != nil {
			break
		}
		if start, ok := tok.(xml.StartElement); ok {
			isAggregate = start.Name.Local == "EntitiesDescriptor"
			break
		}
	}

	if isAggregate {
		entities := &saml.EntitiesDescriptor{}
		if err := xml.Unmarshal(data, entities); err != nil {
			return nil, err
		}
		for _, e := range entities.EntityDescriptors {
			if len(e.SPSSODescriptors) > 0 {
				// e is a copy of the slice element, so the returned pointer
				// does not keep the rest of the aggregate alive.
				return &e, nil
			}
		}
		// there were no SPSSODescriptors in the aggregate
		return nil, errors.New("metadata contained no service provider metadata")
	}

	spMetadata = &saml.EntityDescriptor{}
	if err := xml.Unmarshal(data, spMetadata); err != nil {
		return nil, err
	}
	return spMetadata, nil
}
// BasicServiceProviderProvider is an in-memory saml.ServiceProviderProvider
// whose service provider metadata is loaded by NewServiceProviderProvider.
type BasicServiceProviderProvider struct {
	// serviceProviders maps a service provider's EntityID to its parsed
	// metadata; it is filled by NewServiceProviderProvider and read by
	// GetServiceProvider.
	serviceProviders map[string]*saml.EntityDescriptor
}
// modifyACSBindings rewrites, in place, every AssertionConsumerService of
// every SPSSODescriptor in entityDescriptor that uses the HTTP-Redirect
// binding so that it uses HTTP-POST instead, logging each endpoint it
// changes. The SAML library only supports sending responses from the IdP to
// the SP over HTTP-POST, so redirect-bound endpoints would otherwise be
// unusable.
func modifyACSBindings(entityDescriptor *saml.EntityDescriptor) {
	const (
		redirectBinding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
		postBinding     = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
	)
	for d := range entityDescriptor.SPSSODescriptors {
		// services shares its backing array with the descriptor, so writes
		// through services[k] update the metadata itself.
		services := entityDescriptor.SPSSODescriptors[d].AssertionConsumerServices
		for k := range services {
			if services[k].Binding != redirectBinding {
				continue
			}
			log.Printf("Replacing Binding for %s from HTTP-Redirect to HTTP-POST", services[k].Location)
			services[k].Binding = postBinding
		}
	}
}
// NewServiceProviderProvider builds a saml.ServiceProviderProvider from a
// list of service provider metadata sources. Each entry of spMetadataPaths is
// either an http:// or https:// URL, fetched with a GET request, or a local
// file path. Every loaded entity has its HTTP-Redirect
// AssertionConsumerService bindings rewritten to HTTP-POST (see
// modifyACSBindings) and is indexed by its EntityID; a later entry with the
// same EntityID replaces an earlier one.
//
// All metadata is loaded during this call. Any failure to fetch, read, or
// parse a source panics with an error naming that source.
func NewServiceProviderProvider(spMetadataPaths []string) saml.ServiceProviderProvider {
	spp := &BasicServiceProviderProvider{
		serviceProviders: make(map[string]*saml.EntityDescriptor, len(spMetadataPaths)),
	}

	// The zero http.Client has no timeout, so an unresponsive metadata
	// server would hang startup forever. The timeout also covers reading the
	// response body, hence a generous value for large aggregate metadata.
	client := &http.Client{Timeout: time.Minute}

	// openSPMetadata returns a reader over the metadata at location, which is
	// an http(s) URL or a local file path. The caller must close it.
	openSPMetadata := func(location string) (io.ReadCloser, error) {
		if !strings.HasPrefix(location, "http://") && !strings.HasPrefix(location, "https://") {
			f, err := os.Open(location)
			if err != nil {
				return nil, err
			}
			return f, nil
		}
		resp, err := client.Get(location)
		if err != nil {
			return nil, err
		}
		// A non-2xx body is an error page, not metadata: report the status
		// instead of surfacing a confusing XML parse error later.
		if resp.StatusCode < 200 || resp.StatusCode > 299 {
			resp.Body.Close()
			return nil, fmt.Errorf("GET %s: unexpected status %s", location, resp.Status)
		}
		return resp.Body, nil
	}

	// loadSPMetadata opens location and parses it with getSPMetadata,
	// closing the underlying reader when done.
	loadSPMetadata := func(location string) (*saml.EntityDescriptor, error) {
		r, err := openSPMetadata(location)
		if err != nil {
			return nil, err
		}
		defer r.Close()
		return getSPMetadata(r)
	}

	for _, location := range spMetadataPaths {
		metadata, err := loadSPMetadata(location)
		if err != nil {
			panic(fmt.Errorf("loading service provider metadata from %q: %w", location, err))
		}
		modifyACSBindings(metadata)
		spp.serviceProviders[metadata.EntityID] = metadata
	}
	return spp
}
// GetServiceProvider looks up the metadata registered for serviceProviderID
// (typically the service provider's entity ID / metadata URL) among the
// entries loaded by NewServiceProviderProvider. The request r is not
// consulted. When no matching service provider is known the error is
// os.ErrNotExist, as the saml.ServiceProviderProvider contract requires.
func (s *BasicServiceProviderProvider) GetServiceProvider(r *http.Request, serviceProviderID string) (*saml.EntityDescriptor, error) {
	if descriptor, found := s.serviceProviders[serviceProviderID]; found {
		return descriptor, nil
	}
	return nil, os.ErrNotExist
}