diff --git a/saml_test.go b/saml_test.go index d59a5e1..6141435 100644 --- a/saml_test.go +++ b/saml_test.go @@ -25,8 +25,8 @@ import ( "encoding/pem" "encoding/xml" "fmt" - "io/ioutil" "log" + "os" "testing" "github.com/beevik/etree" @@ -49,7 +49,7 @@ func init() { } func TestDecode(t *testing.T) { - f, err := ioutil.ReadFile("./testdata/saml.post") + f, err := os.ReadFile("./testdata/saml.post") if err != nil { t.Fatalf("could not open test file: %v\n", err) } @@ -65,7 +65,7 @@ func TestDecode(t *testing.T) { ea := response.EncryptedAssertions[0] - k, err := ea.EncryptedKey.DecryptSymmetricKey(&cert) + k, err := ea.EncryptedKey.DecryptSymmetricKey(&cert, ea.EncryptionMethod.Algorithm) if err != nil { t.Fatalf("could not get symmetric key: %v\n", err) } @@ -79,7 +79,7 @@ func TestDecode(t *testing.T) { t.Fatalf("error decrypting saml data: %v\n", err) } - f2, err := ioutil.ReadFile("./testdata/saml.xml") + f2, err := os.ReadFile("./testdata/saml.xml") if err != nil { t.Fatalf("could not read expected output") } diff --git a/types/encrypted_assertion.go b/types/encrypted_assertion.go index 150e505..02cdd17 100644 --- a/types/encrypted_assertion.go +++ b/types/encrypted_assertion.go @@ -44,7 +44,7 @@ func (ea *EncryptedAssertion) DecryptBytes(cert *tls.Certificate) ([]byte, error // https://www.w3.org/TR/2002/REC-xmlenc-core-20021210/Overview.html#sec-Extensions-to-KeyInfo ek = &ea.DetEncryptedKey } - k, err := ek.DecryptSymmetricKey(cert) + k, err := ek.DecryptSymmetricKey(cert, ea.EncryptionMethod.Algorithm) if err != nil { return nil, fmt.Errorf("cannot decrypt, error retrieving private key: %s", err) } diff --git a/types/encrypted_key.go b/types/encrypted_key.go index 82db74b..12ffbae 100644 --- a/types/encrypted_key.go +++ b/types/encrypted_key.go @@ -18,6 +18,7 @@ import ( "bytes" "crypto/aes" "crypto/cipher" + "crypto/des" "crypto/rand" "crypto/rsa" "crypto/sha1" @@ -31,8 +32,8 @@ import ( "strings" ) -//EncryptedKey contains the decryption key data from the saml2 core and xmlenc -//standards. +// EncryptedKey contains the decryption key data from the saml2 core and xmlenc +// standards. type EncryptedKey struct { // EncryptionMethod string `xml:"EncryptionMethod>Algorithm"` X509Data string `xml:"KeyInfo>X509Data>X509Certificate"` @@ -40,7 +41,7 @@ type EncryptedKey struct { EncryptionMethod EncryptionMethod } -//EncryptionMethod specifies the type of encryption that was used. +// EncryptionMethod specifies the type of encryption that was used. type EncryptionMethod struct { Algorithm string `xml:",attr,omitempty"` //Digest method is present for algorithms like RSA-OAEP. @@ -51,19 +52,19 @@ type EncryptionMethod struct { DigestMethod *DigestMethod `xml:",omitempty"` } -//DigestMethod is a digest type specification +// DigestMethod is a digest type specification type DigestMethod struct { Algorithm string `xml:",attr,omitempty"` } -//Well-known public-key encryption methods +// Well-known public-key encryption methods const ( MethodRSAOAEP = "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p" MethodRSAOAEP2 = "http://www.w3.org/2009/xmlenc11#rsa-oaep" MethodRSAv1_5 = "http://www.w3.org/2001/04/xmlenc#rsa-1_5" ) -//Well-known private key encryption methods +// Well-known private key encryption methods const ( MethodAES128GCM = "http://www.w3.org/2009/xmlenc11#aes128-gcm" MethodAES192GCM = "http://www.w3.org/2009/xmlenc11#aes192-gcm" @@ -73,15 +74,15 @@ const ( MethodTripleDESCBC = "http://www.w3.org/2001/04/xmlenc#tripledes-cbc" ) -//Well-known hash methods +// Well-known hash methods const ( MethodSHA1 = "http://www.w3.org/2000/09/xmldsig#sha1" MethodSHA256 = "http://www.w3.org/2000/09/xmldsig#sha256" MethodSHA512 = "http://www.w3.org/2000/09/xmldsig#sha512" ) -//SHA-1 is commonly used for certificate fingerprints (openssl -fingerprint and ADFS thumbprint). -//SHA-1 is sufficient for our purposes here (error message). +// SHA-1 is commonly used for certificate fingerprints (openssl -fingerprint and ADFS thumbprint). +// SHA-1 is sufficient for our purposes here (error message). func debugKeyFp(keyBytes []byte) string { if len(keyBytes) < 1 { return "" @@ -100,8 +101,8 @@ func debugKeyFp(keyBytes []byte) string { return ret } -//DecryptSymmetricKey returns the private key contained in the EncryptedKey document -func (ek *EncryptedKey) DecryptSymmetricKey(cert *tls.Certificate) (cipher.Block, error) { +// DecryptSymmetricKey returns the private key contained in the EncryptedKey document +func (ek *EncryptedKey) DecryptSymmetricKey(cert *tls.Certificate, algo string) (cipher.Block, error) { if len(cert.Certificate) < 1 { return nil, fmt.Errorf("decryption tls.Certificate has no public certs attached") } @@ -146,6 +147,14 @@ func (ek *EncryptedKey) DecryptSymmetricKey(cert *tls.Certificate) (cipher.Block } } + cipher := func(k []byte) (cipher.Block, error) { + if algo == MethodTripleDESCBC { + return des.NewTripleDESCipher(k) + } else { + return aes.NewCipher(k) + } + } + switch ek.EncryptionMethod.Algorithm { case "": return nil, fmt.Errorf("missing encryption algorithm") @@ -155,7 +164,7 @@ func (ek *EncryptedKey) DecryptSymmetricKey(cert *tls.Certificate) (cipher.Block return nil, fmt.Errorf("rsa internal error: %v", err) } - b, err := aes.NewCipher(pt) + b, err := cipher(pt) if err != nil { return nil, err } @@ -175,7 +184,7 @@ func (ek *EncryptedKey) DecryptSymmetricKey(cert *tls.Certificate) (cipher.Block //The RSA v1.5 Key Transport algorithm given below are those used in conjunction with TRIPLEDES //Please also see https://www.w3.org/TR/xmlenc-core/#sec-Algorithms and //https://www.w3.org/TR/xmlenc-core/#rsav15note. - b, err := aes.NewCipher(pt) + b, err := cipher(pt) if err != nil { return nil, err }