diff --git a/README.md b/README.md index 7b41767..78ceda7 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ sudo dnf install python2-devel python3-devel openldap-devel | base_ou | yes | | Base OU to search for user and group entries | | group_dns | yes | | Which groups user must be member of to be granted access (group names are considered case-insensitive) | | group_dns_check | no | `and` | What kind of check to perform when validating user group membership (`and` / `or`). When `and` behavior is used, user needs to be part of all the specified groups and when `or` behavior is used, user needs to be part of at least one or more of the specified groups. | -| host | yes | | Hostname of the LDAP server | +| host | yes | | Hostname of the LDAP server. Multiple comma-separated entries are allowed. | | port | yes | | Port of the LDAP server | | use_ssl | no | `false` | Use LDAPS to connect | | use_tls | no | `false` | Start TLS on LDAP to connect | diff --git a/st2auth_ldap/ldap_backend.py b/st2auth_ldap/ldap_backend.py index 9dc1e69..e40a671 100644 --- a/st2auth_ldap/ldap_backend.py +++ b/st2auth_ldap/ldap_backend.py @@ -279,7 +279,11 @@ def _init_connection(self): # Setup connection and options. protocol = 'ldaps' if self._use_ssl else 'ldap' - endpoint = '%s://%s:%d' % (protocol, self._host, int(self._port)) + hosts = self._host.split(',') + for i in range(len(hosts)): + hosts[i] = '%s://%s:%d' % (protocol, hosts[i], int(self._port)) + + endpoint = ','.join(hosts) connection = ldap.initialize(endpoint, trace_level=trace_level) connection.set_option(ldap.OPT_DEBUG_LEVEL, 255) connection.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3) diff --git a/tests/unit/test_backend.py b/tests/unit/test_backend.py index 8d9e81e..9a3be49 100644 --- a/tests/unit/test_backend.py +++ b/tests/unit/test_backend.py @@ -25,6 +25,7 @@ LDAP_HOST = '127.0.0.1' +LDAP_MULTIPLE_HOSTS = '127.0.0.1,localhost' LDAPS_PORT = 636 LDAP_BIND_DN = 'cn=Administrator,cn=users,dc=stackstorm,dc=net' LDAP_BIND_PASSWORD = uuid.uuid4().hex @@ -114,6 +115,25 @@ def test_authenticate(self): authenticated = backend.authenticate(LDAP_USER_UID, LDAP_USER_PASSWD) self.assertTrue(authenticated) + @mock.patch.object( + ldap.ldapobject.SimpleLDAPObject, 'simple_bind_s', + mock.MagicMock(return_value=None)) + @mock.patch.object( + ldap.ldapobject.SimpleLDAPObject, 'search_s', + mock.MagicMock(side_effect=[LDAP_USER_SEARCH_RESULT, LDAP_GROUP_SEARCH_RESULT])) + def test_authenticate_with_multiple_ldap_hosts(self): + backend = ldap_backend.LDAPAuthenticationBackend( + LDAP_BIND_DN, + LDAP_BIND_PASSWORD, + LDAP_BASE_OU, + LDAP_GROUP_DNS, + LDAP_MULTIPLE_HOSTS, + id_attr=LDAP_ID_ATTR + ) + + authenticated = backend.authenticate(LDAP_USER_UID, LDAP_USER_PASSWD) + self.assertTrue(authenticated) + @mock.patch.object( ldap.ldapobject.SimpleLDAPObject, 'simple_bind_s', mock.MagicMock(return_value=None))