Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix http://www.w3.org/2001/04/xmlenc#tripledes-cbc encrypted assertions #198

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions saml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import (
"encoding/pem"
"encoding/xml"
"fmt"
"io/ioutil"
"log"
"os"
"testing"

"github.com/beevik/etree"
Expand All @@ -49,7 +49,7 @@ func init() {
}

func TestDecode(t *testing.T) {
f, err := ioutil.ReadFile("./testdata/saml.post")
f, err := os.ReadFile("./testdata/saml.post")
if err != nil {
t.Fatalf("could not open test file: %v\n", err)
}
Expand All @@ -65,7 +65,7 @@ func TestDecode(t *testing.T) {

ea := response.EncryptedAssertions[0]

k, err := ea.EncryptedKey.DecryptSymmetricKey(&cert)
k, err := ea.EncryptedKey.DecryptSymmetricKey(&cert, ea.EncryptionMethod.Algorithm)
if err != nil {
t.Fatalf("could not get symmetric key: %v\n", err)
}
Expand All @@ -79,7 +79,7 @@ func TestDecode(t *testing.T) {
t.Fatalf("error decrypting saml data: %v\n", err)
}

f2, err := ioutil.ReadFile("./testdata/saml.xml")
f2, err := os.ReadFile("./testdata/saml.xml")
if err != nil {
t.Fatalf("could not read expected output")
}
Expand Down
2 changes: 1 addition & 1 deletion types/encrypted_assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (ea *EncryptedAssertion) DecryptBytes(cert *tls.Certificate) ([]byte, error
// https://www.w3.org/TR/2002/REC-xmlenc-core-20021210/Overview.html#sec-Extensions-to-KeyInfo
ek = &ea.DetEncryptedKey
}
k, err := ek.DecryptSymmetricKey(cert)
k, err := ek.DecryptSymmetricKey(cert, ea.EncryptionMethod.Algorithm)
if err != nil {
return nil, fmt.Errorf("cannot decrypt, error retrieving private key: %s", err)
}
Expand Down
35 changes: 22 additions & 13 deletions types/encrypted_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
Expand All @@ -31,16 +32,16 @@ import (
"strings"
)

//EncryptedKey contains the decryption key data from the saml2 core and xmlenc
//standards.
// EncryptedKey contains the decryption key data from the saml2 core and xmlenc
// standards.
type EncryptedKey struct {
// EncryptionMethod string `xml:"EncryptionMethod>Algorithm"`
X509Data string `xml:"KeyInfo>X509Data>X509Certificate"`
CipherValue string `xml:"CipherData>CipherValue"`
EncryptionMethod EncryptionMethod
}

//EncryptionMethod specifies the type of encryption that was used.
// EncryptionMethod specifies the type of encryption that was used.
type EncryptionMethod struct {
Algorithm string `xml:",attr,omitempty"`
//Digest method is present for algorithms like RSA-OAEP.
Expand All @@ -51,19 +52,19 @@ type EncryptionMethod struct {
DigestMethod *DigestMethod `xml:",omitempty"`
}

//DigestMethod is a digest type specification
// DigestMethod is a digest type specification
type DigestMethod struct {
Algorithm string `xml:",attr,omitempty"`
}

//Well-known public-key encryption methods
// Well-known public-key encryption methods
const (
MethodRSAOAEP = "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p"
MethodRSAOAEP2 = "http://www.w3.org/2009/xmlenc11#rsa-oaep"
MethodRSAv1_5 = "http://www.w3.org/2001/04/xmlenc#rsa-1_5"
)

//Well-known private key encryption methods
// Well-known private key encryption methods
const (
MethodAES128GCM = "http://www.w3.org/2009/xmlenc11#aes128-gcm"
MethodAES192GCM = "http://www.w3.org/2009/xmlenc11#aes192-gcm"
Expand All @@ -73,15 +74,15 @@ const (
MethodTripleDESCBC = "http://www.w3.org/2001/04/xmlenc#tripledes-cbc"
)

//Well-known hash methods
// Well-known hash methods
const (
MethodSHA1 = "http://www.w3.org/2000/09/xmldsig#sha1"
MethodSHA256 = "http://www.w3.org/2000/09/xmldsig#sha256"
MethodSHA512 = "http://www.w3.org/2000/09/xmldsig#sha512"
)

//SHA-1 is commonly used for certificate fingerprints (openssl -fingerprint and ADFS thumbprint).
//SHA-1 is sufficient for our purposes here (error message).
// SHA-1 is commonly used for certificate fingerprints (openssl -fingerprint and ADFS thumbprint).
// SHA-1 is sufficient for our purposes here (error message).
func debugKeyFp(keyBytes []byte) string {
if len(keyBytes) < 1 {
return ""
Expand All @@ -100,8 +101,8 @@ func debugKeyFp(keyBytes []byte) string {
return ret
}

//DecryptSymmetricKey returns the private key contained in the EncryptedKey document
func (ek *EncryptedKey) DecryptSymmetricKey(cert *tls.Certificate) (cipher.Block, error) {
// DecryptSymmetricKey returns the private key contained in the EncryptedKey document
func (ek *EncryptedKey) DecryptSymmetricKey(cert *tls.Certificate, algo string) (cipher.Block, error) {
if len(cert.Certificate) < 1 {
return nil, fmt.Errorf("decryption tls.Certificate has no public certs attached")
}
Expand Down Expand Up @@ -146,6 +147,14 @@ func (ek *EncryptedKey) DecryptSymmetricKey(cert *tls.Certificate) (cipher.Block
}
}

cipher := func(k []byte) (cipher.Block, error) {
if algo == MethodTripleDESCBC {
return des.NewTripleDESCipher(k)
} else {
return aes.NewCipher(k)
}
}

switch ek.EncryptionMethod.Algorithm {
case "":
return nil, fmt.Errorf("missing encryption algorithm")
Expand All @@ -155,7 +164,7 @@ func (ek *EncryptedKey) DecryptSymmetricKey(cert *tls.Certificate) (cipher.Block
return nil, fmt.Errorf("rsa internal error: %v", err)
}

b, err := aes.NewCipher(pt)
b, err := cipher(pt)
if err != nil {
return nil, err
}
Expand All @@ -175,7 +184,7 @@ func (ek *EncryptedKey) DecryptSymmetricKey(cert *tls.Certificate) (cipher.Block
//The RSA v1.5 Key Transport algorithm given below are those used in conjunction with TRIPLEDES
//Please also see https://www.w3.org/TR/xmlenc-core/#sec-Algorithms and
//https://www.w3.org/TR/xmlenc-core/#rsav15note.
b, err := aes.NewCipher(pt)
b, err := cipher(pt)
if err != nil {
return nil, err
}
Expand Down