Source code for ldap_sync.sources.ldap

#  Copyright (c) 2022. The Pycroft Authors. See the AUTHORS file.
#  This file is part of the Pycroft project and licensed under the terms of
#  the Apache License, Version 2.0. See the LICENSE file for details
"""
ldap_sync.sources.ldap
~~~~~~~~~~~~~~~~~~~~~~
"""
import ssl
import typing

import ldap3

from .. import logger, conversion
from ..concepts.record import UserRecord, GroupRecord
from ..config import SyncConfig
from ..concepts.types import LdapRecord


[docs] def establish_and_return_ldap_connection(config: SyncConfig) -> ldap3.Connection: tls = None if config.ca_certs_file or config.ca_certs_data: tls = ldap3.Tls( ca_certs_file=config.ca_certs_file, ca_certs_data=config.ca_certs_data, validate=ssl.CERT_REQUIRED, ) server = ldap3.Server( host=config.host, port=config.port, use_ssl=config.use_ssl, tls=tls ) return ldap3.Connection( server, user=config.bind_dn, password=config.bind_pw, auto_bind=True )
[docs] def _fetch_ldap_entries( connection: ldap3.Connection, base_dn: str, search_filter: str | None = None, attributes: str | typing.Collection[str] = ldap3.ALL_ATTRIBUTES, ) -> list[LdapRecord]: success = connection.search( search_base=base_dn, search_filter=search_filter, attributes=attributes ) if not success: logger.warning("LDAP search not successful. Result: %s", connection.result) return [] return [r for r in connection.response if r["dn"] != base_dn]
[docs] def _fetch_ldap_users(connection: ldap3.Connection, base_dn: str) -> list[LdapRecord]: return _fetch_ldap_entries( connection, base_dn, search_filter="(objectclass=inetOrgPerson)", attributes=[ldap3.ALL_ATTRIBUTES, "pwdAccountLockedTime"], )
[docs] def fetch_ldap_users( connection: ldap3.Connection, base_dn: str ) -> typing.Iterator[UserRecord]: for r in _fetch_ldap_users(connection, base_dn): yield conversion.ldap_user_to_record(r)
[docs] def _fetch_ldap_groups(connection: ldap3.Connection, base_dn: str) -> list[LdapRecord]: return _fetch_ldap_entries( connection, base_dn, search_filter="(objectclass=groupOfMembers)" )
[docs] def fetch_ldap_groups( connection: ldap3.Connection, base_dn: str ) -> typing.Iterator[GroupRecord]: for r in _fetch_ldap_groups(connection, base_dn): yield conversion.ldap_group_to_record(r)
[docs] def _fetch_ldap_properties( connection: ldap3.Connection, base_dn: str ) -> list[LdapRecord]: return _fetch_ldap_entries( connection, base_dn, search_filter="(objectclass=groupOfMembers)" )
[docs] def fetch_ldap_properties( connection: ldap3.Connection, base_dn: str ) -> typing.Iterator[GroupRecord]: for r in _fetch_ldap_properties(connection, base_dn): yield conversion.ldap_group_to_record(r)
[docs] def fake_connection() -> ldap3.Connection: server = ldap3.Server("mocked") connection = ldap3.Connection(server, client_strategy=ldap3.MOCK_SYNC) connection.open() return connection