// adapted from https://github.com/crewjam/saml/blob/main/samlidp/service.go package main import ( "bytes" "encoding/xml" "errors" "io" "io/ioutil" "log" "net/http" "os" "strings" "github.com/crewjam/saml" xrv "github.com/mattermost/xml-roundtrip-validator" ) // copied from https://github.com/crewjam/saml/blob/main/samlidp/util.go func getSPMetadata(r io.Reader) (spMetadata *saml.EntityDescriptor, err error) { var data []byte if data, err = ioutil.ReadAll(r); err != nil { return nil, err } spMetadata = &saml.EntityDescriptor{} if err := xrv.Validate(bytes.NewBuffer(data)); err != nil { return nil, err } if err := xml.Unmarshal(data, &spMetadata); err != nil { if err.Error() == "expected element type but have " { entities := &saml.EntitiesDescriptor{} if err := xml.Unmarshal(data, &entities); err != nil { return nil, err } for _, e := range entities.EntityDescriptors { if len(e.SPSSODescriptors) > 0 { return &e, nil } } // there were no SPSSODescriptors in the response return nil, errors.New("metadata contained no service provider metadata") } return nil, err } return spMetadata, nil } type BasicServiceProviderProvider struct { serviceProviders map[string]*saml.EntityDescriptor } func modifyACSBindings(entityDescriptor *saml.EntityDescriptor) { // The SAML library only allows the HTTP-POST Binding (from the IdP // to the SP), so we need to modify the AssertionConsumerService // endpoints which use HTTP-Redirect to use HTTP-POST. for i := 0; i < len(entityDescriptor.SPSSODescriptors); i++ { spSSODescriptor := &entityDescriptor.SPSSODescriptors[i] for j := 0; j < len(spSSODescriptor.AssertionConsumerServices); j++ { acs := &spSSODescriptor.AssertionConsumerServices[j] if acs.Binding == "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" { log.Printf("Replacing Binding for %s from HTTP-Redirect to HTTP-POST", acs.Location) acs.Binding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" } } } } func NewServiceProviderProvider(spMetadataPaths []string) saml.ServiceProviderProvider { spp := &BasicServiceProviderProvider{ serviceProviders: make(map[string]*saml.EntityDescriptor), } client := http.Client{} loadSPMetadata := func(filename string) (*saml.EntityDescriptor, error) { var r io.ReadCloser var err error if strings.HasPrefix(filename, "http://") || strings.HasPrefix(filename, "https://") { resp, err := client.Get(filename) if err != nil { return nil, err } r = resp.Body } else { r, err = os.Open(filename) if err != nil { return nil, err } } defer r.Close() return getSPMetadata(r) } for _, filename := range spMetadataPaths { metadata, err := loadSPMetadata(filename) if err != nil { panic(err) } modifyACSBindings(metadata) spp.serviceProviders[metadata.EntityID] = metadata } return spp } // GetServiceProvider returns the Service Provider metadata for the // service provider ID, which is typically the service provider's // metadata URL. If an appropriate service provider cannot be found then // the returned error must be os.ErrNotExist. func (s *BasicServiceProviderProvider) GetServiceProvider(r *http.Request, serviceProviderID string) (*saml.EntityDescriptor, error) { rv, ok := s.serviceProviders[serviceProviderID] if !ok { return nil, os.ErrNotExist } return rv, nil }