Python library for interacting with LDAP
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

581 lines
23 KiB

import base64
import re
import ldap
import TSIGKey
from exceptions import *
from enum import Enum
from ldap import modlist, filter
from colorlog import ColoredFormatter
import logging
from pprint import pprint
def _get_readable_ldap_scope(scope: int) -> str:
    """Return a human-readable name for a python-ldap search scope.

    Only used to build debug log lines.

    :param scope: One of the ``ldap.SCOPE_*`` constants.
    :return: ``'subordinate'``, ``'subtree'``, ``'onelevel'`` or ``'base'``.
        Any value that is not recognised falls back to ``'base'``.
    """
    if scope == ldap.SCOPE_SUBORDINATE:
        return 'subordinate'
    elif scope == ldap.SCOPE_SUBTREE:
        # Previously reported as 'subordinate' (copy-paste slip).
        return 'subtree'
    elif scope == ldap.SCOPE_ONELEVEL:
        # Previously reported as 'subordinate' (copy-paste slip).
        return 'onelevel'
    return 'base'
class KeyType(Enum):
    """Kind of object a TSIG key belongs to.

    The member value is the trailing label of a TSIG key name, which has the
    form ``<identifier>.<value>``.
    """

    # Key owned by a DNS server entry.
    Server = 'server'
    # Key owned by a user entry.
    User = 'user'
    # Key owned by a zone entry.
    Zone = 'zone'
class PlabsDNS:
    """LDAP-backed management of DNS zones, servers, users and TSIG keys."""

    def __init__(self, config: dict, logger: logging.Logger = None, ldap_connection: ldap.ldapobject.LDAPObject = None):
        """Read the LDAP layout from *config* and prepare logging, the
        connection and the validation patterns.

        :param config: Mapping with an ``'ldap'`` section providing
            ``base_dn`` and ``base_user_dn``; ``uri``, ``bind_dn`` and
            ``bind_pw`` are read only when no *ldap_connection* is given.
        :param logger: Logger to use; when omitted a ``plabsDNS`` logger at
            DEBUG level is created via ``_setup_logging``.
        :param ldap_connection: Already established connection; when omitted
            a new one is opened and bound via ``_ldap_connect``.
        """
        self.base_dn = config['ldap']['base_dn']
        self.base_user_dn = config['ldap']['base_user_dn']
        # _build_dn() reads self.base_dn, so that must already be set here.
        self.base_zone_dn = self._build_dn('ou=zones')
        self.base_server_dn = self._build_dn('ou=servers')
        # The logger must exist before _ldap_connect(), which logs on success.
        self.logger = logger if logger else self._setup_logging(logging.DEBUG)
        self.ldap_connection = ldap_connection if ldap_connection else \
            self._ldap_connect(config['ldap']['uri'], config['ldap']['bind_dn'], config['ldap']['bind_pw'])

        # All patterns are raw strings: '\.', '\w' and '\d' are invalid escape
        # sequences in ordinary string literals and trigger warnings on
        # current Python versions.
        # Dot-separated labels (no leading/trailing '-') ending in a
        # 2-6 letter top-level label.
        self.domain_regex = re.compile(r'^((?!-)[A-Za-z0-9-]{1,63}(?<!-)\.)*[A-Za-z]{2,6}$')
        self.user_regex = re.compile(r'^[\w\d]+$')
        # '-' is accepted in DN values because domain_regex accepts it in
        # labels; without it, DNs built from hyphenated domain or server names
        # were rejected by every DN check in this class.
        self.dn_regex = re.compile(r'^(?:[\w]+=[\w\d.-]+,)*[\w]+=[\w\d-]+$')
        self.ipv4_regex = re.compile(r'^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$')
        # NOTE: this pattern is not anchored; use fullmatch() when validating.
        self.ipv6_regex = re.compile(r'(([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,7}:|([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|:((:[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))')
def _setup_logging(self, log_level=logging.WARNING):
    """Create (or reuse) the ``plabsDNS`` logger with coloured stream output.

    :param log_level: Level set on the logger.
    :return: The configured ``logging.Logger``.
    """
    log_format = '%(log_color)s[%(name)s %(levelname)s]: %(message)s'
    logger = logging.getLogger('plabsDNS')
    logger.setLevel(log_level)
    # logging.getLogger() hands back the same object for the same name, so a
    # handler is attached only once; otherwise every additional PlabsDNS
    # instance would make each message appear one more time.
    if not logger.handlers:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(ColoredFormatter(log_format))
        logger.addHandler(stream_handler)
    return logger
def _log_ldap_search(self, dn, scope, filter_string=None, attrib_list=None):
    """Write a debug line describing an LDAP search.

    :param dn: Search base.
    :param scope: ``ldap.SCOPE_*`` constant, rendered in readable form.
    :param filter_string: Filter expression; omitted from the line if falsy.
    :param attrib_list: Requested attribute names; omitted if empty.
    """
    pieces = ['LDAP {} search on \'{}\''.format(_get_readable_ldap_scope(scope), dn)]
    if filter_string:
        pieces.append(' with filter \'{}\''.format(filter_string))
    if attrib_list:
        pieces.append(' for attributes [' + ', '.join(attrib_list) + ']')
    self.logger.debug(''.join(pieces))
def _log_ldap_add(self, dn, attrib_list):
    """Write a debug line describing an LDAP add operation.

    :param dn: Distinguished name of the entry being added.
    :param attrib_list: Iterable of ``(attribute, value)`` pairs, as returned
        by ``ldap.modlist.addModlist``. Each value is ``bytes`` or a
        list/tuple of ``bytes``.
    """
    rendered = []
    for name, value in attrib_list:
        if isinstance(value, (list, tuple)):
            rendered.append(name + '=(' + ','.join(x.decode('utf-8') for x in value) + ')')
        else:
            rendered.append(name + '=' + value.decode('utf-8'))
    # Joining the parts (instead of trimming a trailing ', ') keeps the
    # opening bracket intact when there are no attributes at all.
    self.logger.debug('LDAP adding dn \'{}\' with attributes [{}]'.format(dn, ', '.join(rendered)))
def _log_ldap_modify(self, dn, attrib_list):
    """Write a debug line describing an LDAP modify operation.

    :param dn: Distinguished name of the entry being modified.
    :param attrib_list: Iterable of ``(action, attribute, value)`` triples
        where *action* is one of ``ldap.MOD_ADD``, ``ldap.MOD_DELETE`` or
        ``ldap.MOD_REPLACE`` and *value* is ``bytes`` or a list/tuple of
        ``bytes``.
    """
    action_names = {
        ldap.MOD_REPLACE: 'REPLACE',
        ldap.MOD_DELETE: 'DELETE',
        ldap.MOD_ADD: 'ADD',
    }
    rendered = []
    for action, name, value in attrib_list:
        # Multi-valued entries are rendered the same way _log_ldap_add does.
        if isinstance(value, (list, tuple)):
            text = '(' + ','.join(x.decode('utf-8') for x in value) + ')'
        else:
            text = value.decode('utf-8')
        rendered.append('{}: '.format(action_names.get(action, 'UNKNOWN')) + name + '=' + text)
    # Joining keeps the opening bracket intact for an empty modification list.
    self.logger.debug('LDAP modifying dn \'{}\' with attributes [{}]'.format(dn, ', '.join(rendered)))
def _log_ldap_delete(self, dn):
    """Write a debug line announcing the deletion of *dn*."""
    message = 'LDAP deleting dn \'{}\''.format(dn)
    self.logger.debug(message)
def _ldap_connect(self, uri: str, bind_dn: str, bind_pw: str):
    """Open a connection to *uri* and authenticate with a simple bind.

    :param uri: LDAP URI handed to ``ldap.initialize``.
    :param bind_dn: DN to bind as.
    :param bind_pw: Password for *bind_dn*.
    :return: The bound connection object.
    """
    conn = ldap.initialize(uri)
    # OPT_REFERRALS is set to 0 for this connection.
    conn.set_option(ldap.OPT_REFERRALS, 0)
    # NOTE(review): StartTLS is disabled here; presumably the bind password
    # is sent unencrypted unless *uri* uses ldaps:// - confirm.
    # conn.start_tls_s()
    conn.simple_bind_s(bind_dn, bind_pw)
    self.logger.debug('LDAP connection to {} successful'.format(uri))
    return conn
def _build_dn(self, *dcs: str) -> str:
    """Join DN components with commas and anchor the result at the base DN.

    :param dcs: DN fragments, most specific first.
    :return: The joined fragments, followed by ``,<base_dn>`` unless the
        joined string already ends with ``self.base_dn``.
    """
    joined = ','.join(dcs)
    if not joined.endswith(self.base_dn):
        joined = joined + ',' + self.base_dn
    return joined
def _dn_to_domain(self, dn: str) -> str:
    """Convert a zone DN below ``base_zone_dn`` into a dotted domain name.

    :param dn: DN of the form ``dc=<label>,...,<base_zone_dn>``.
    :return: The labels joined with dots, in DN order.
    :raises InvalidDistinguishedName: If *dn* does not match ``dn_regex``.
    :raises PlabsException: If *dn* does not end with ``base_zone_dn``.
    """
    if self.dn_regex.search(dn) is None:
        raise InvalidDistinguishedName(dn)
    if not dn.endswith(self.base_zone_dn):
        raise PlabsException('Invalid DN for zone')
    # Cut off the zone base DN together with the comma in front of it.
    zone_rdns = dn[:-(len(self.base_zone_dn) + 1)].split(',')
    # Each component is expected to look like 'dc=<label>'; drop the
    # three-character 'dc=' prefix.
    return '.'.join(rdn[3:] for rdn in zone_rdns)
def _domain_to_dn(self, domain: str):
    """Convert a dotted domain name into its DN below ``base_zone_dn``.

    :param domain: Domain name matching ``domain_regex``.
    :return: ``dc=<label>,...`` for every label, followed by the zone base.
    :raises InvalidDomainName: If *domain* does not match ``domain_regex``.
    """
    if self.domain_regex.search(domain) is None:
        raise InvalidDomainName(domain)
    rdns = ','.join('dc=' + label for label in domain.split('.'))
    return self._build_dn(rdns, self.base_zone_dn)
def _get_ldap_attr(self, obj: list, attrib: str):
    """Extract one attribute from an LDAP result entry as text.

    :param obj: ``(dn, attributes)`` pair; *attributes* maps attribute names
        to ``bytes`` values or lists of ``bytes`` values.
    :param attrib: Attribute name; the special name ``'dn'`` returns the
        entry's DN (``obj[0]``) unchanged.
    :return: ``None`` if the attribute is missing or an empty list, a ``str``
        for a single value, or a list of ``str`` for several values.
    """
    if attrib == 'dn':
        return obj[0]
    if attrib not in obj[1]:
        return None
    values = obj[1][attrib]
    if isinstance(values, bytes):
        # Indexing bytes yields an int (which has no decode()), so a bare
        # bytes value has to be decoded as a whole.
        return values.decode('utf-8')
    if isinstance(values, list):
        if not values:
            return None
        if len(values) == 1:
            return values[0].decode('utf-8')
        return [value.decode('utf-8') for value in values]
    # Any other sequence type: decode its first element, as before.
    return values[0].decode('utf-8')
def ldap_server_lookup(self, identifier: str, attrlist=()):
    """Find exactly one DNS server entry by DN or by common name.

    :param identifier: Either a DN ending in ``base_dn`` (fetched directly)
        or a name matching ``domain_regex`` (searched as ``cn`` below
        ``base_server_dn``).
    :param attrlist: Additional attribute names to fetch besides
        ``objectClass``, ``tsigKey`` and ``cn``; any iterable is accepted.
    :return: The single matching ``(dn, attributes)`` result entry.
    :raises InvalidIdentifier: If *identifier* is neither a DN nor a name.
    :raises ServerNotFound: If no entry matches.
    :raises PlabsException: If more than one entry matches.
    """
    attrlist = tuple(set(('objectClass', 'tsigKey', 'cn') + tuple(attrlist)))
    if self.dn_regex.search(identifier) and identifier.endswith(self.base_dn):
        search_base = identifier
        scope = ldap.SCOPE_BASE
        ldap_filter = None
    elif self.domain_regex.search(identifier):
        # Search where servers actually live. The log line always named
        # base_server_dn, but the query used to run against the whole base_dn
        # and could match unrelated entries sharing the same cn.
        search_base = self.base_server_dn
        scope = ldap.SCOPE_SUBORDINATE
        ldap_filter = ldap.filter.filter_format('(cn=%s)', [identifier])
    else:
        raise InvalidIdentifier

    self._log_ldap_search(search_base, scope, ldap_filter, attrlist)
    try:
        ldap_result = self.ldap_connection.search_s(search_base, scope, ldap_filter, attrlist)
    except ldap.NO_SUCH_OBJECT:
        raise ServerNotFound

    if len(ldap_result) == 0:
        # Was UserNotFound, which is the wrong exception for a server lookup.
        raise ServerNotFound('Server not found: \'{}\''.format(identifier))
    if len(ldap_result) != 1:
        raise PlabsException('Server identifier not unique. {} DNs found'.format(len(ldap_result)))
    return ldap_result[0]
def ldap_user_lookup(self, identifier: str, attrlist=()):
    """Find exactly one user entry by DN, uid or uidNumber.

    :param identifier: Either a DN ending in ``base_user_dn`` (fetched
        directly) or a value matching ``user_regex`` (searched as ``uid`` or
        ``uidNumber`` below ``base_user_dn``).
    :param attrlist: Tuple of additional attribute names to fetch besides
        ``objectClass``, ``tsigKey`` and ``uid``.
    :return: The single matching ``(dn, attributes)`` result entry.
    :raises InvalidIdentifier: If *identifier* matches neither form.
    :raises UserNotFound: If no entry matches.
    :raises PlabsException: If more than one entry matches.
    """
    wanted = tuple(set(('objectClass', 'tsigKey', 'uid') + attrlist))
    if self.dn_regex.search(identifier) and identifier.endswith(self.base_user_dn):
        search_base = identifier
        scope = ldap.SCOPE_BASE
        user_filter = None
    elif self.user_regex.search(identifier):
        search_base = self.base_user_dn
        scope = ldap.SCOPE_SUBORDINATE
        user_filter = ldap.filter.filter_format('(|(uid=%s)(uidNumber=%s))', [identifier, identifier])
    else:
        raise InvalidIdentifier

    self._log_ldap_search(search_base, scope, user_filter, wanted)
    try:
        matches = self.ldap_connection.search_s(search_base, scope, user_filter, wanted)
    except ldap.NO_SUCH_OBJECT:
        raise UserNotFound

    if len(matches) == 0:
        raise UserNotFound('User not found: \'{}\''.format(identifier))
    if len(matches) != 1:
        raise PlabsException('No unique result. {} DNs found'.format(len(matches)))
    return matches[0]
def ldap_user_dns_lookup(self, identifier: str, attrlist=()):
    """Look up a user and require that it is a DNS user with a TSIG key.

    :param identifier: User DN, uid or uidNumber (see ``ldap_user_lookup``).
    :param attrlist: Additional attribute names to fetch; any iterable.
    :return: The user's ``(dn, attributes)`` result entry.
    :raises PlabsException: If the entry lacks the ``plabsDnsUser`` object
        class or has no ``tsigKey``.
    """
    attrlist = tuple(set(('objectClass', 'tsigKey', 'uid') + tuple(attrlist)))
    user = self.ldap_user_lookup(identifier, attrlist)
    # _get_ldap_attr returns None, a str or a list. Normalise to a list so
    # that `in` is a real membership test, not a substring test (str) or a
    # TypeError (None).
    object_classes = self._get_ldap_attr(user, 'objectClass')
    if object_classes is None:
        object_classes = []
    elif isinstance(object_classes, str):
        object_classes = [object_classes]
    if 'plabsDnsUser' not in object_classes:
        raise PlabsException('User {} ({}) is not a DNS user'.format(self._get_ldap_attr(user, 'uid'), user[0]))
    if not self._get_ldap_attr(user, 'tsigKey'):
        raise PlabsException('User {} ({}) has no TSIG Key'.format(self._get_ldap_attr(user, 'uid'), user[0]))
    return user
def ldap_user_add_dns_attrs(self, identifier: str, tsig_key: str = None):
    """Turn an existing user into a DNS user by adding the DNS attributes.

    :param identifier: User DN, uid or uidNumber (see ``ldap_user_lookup``).
    :param tsig_key: TSIG key to store; validated as a user key when given,
        generated from the user's uid when omitted.
    :return: Dict with the user's ``dn`` and the stored ``tsigKey``.
    :raises InvalidTsigKey: If *tsig_key* is given but fails verification.
    :raises PlabsException: If the user already is a DNS user or the LDAP
        modification fails.
    """
    if tsig_key and not self.verify_tsig_key(KeyType.User, tsig_key):
        raise InvalidTsigKey
    user = self.ldap_user_lookup(identifier)
    # Normalise to a list so that `in` is a real membership test, not a
    # substring test (single str) or a TypeError (None).
    object_classes = self._get_ldap_attr(user, 'objectClass')
    if object_classes is None:
        object_classes = []
    elif isinstance(object_classes, str):
        object_classes = [object_classes]
    if 'plabsDnsUser' in object_classes:
        raise PlabsException('User {} ({}) already has DNS attributes'.format(self._get_ldap_attr(user, 'uid'), user[0]))
    if not tsig_key:
        tsig_key = self.generate_tsig_key(KeyType.User, self._get_ldap_attr(user, 'uid'))
    # Add attributes
    mod_list = [(ldap.MOD_ADD, 'objectClass', b'plabsDnsUser'),
                (ldap.MOD_ADD, 'tsigKey', tsig_key.encode('utf-8'))]
    self._log_ldap_modify(user[0], mod_list)
    try:
        self.ldap_connection.modify_s(user[0], mod_list)
    except ldap.LDAPError as err:
        # Keep the LDAP error attached as the cause for easier debugging.
        raise PlabsException('Modify of dn {} failed'.format(user[0])) from err
    return {
        'dn': user[0],
        'tsigKey': tsig_key
    }
def ldap_user_remove_dns_attrs(self, identifier: str):
    """Remove the DNS object class and TSIG key(s) from a user.

    The user must not be referenced as ``zoneAdmin`` by any zone.

    :param identifier: User DN, uid or uidNumber (see ``ldap_user_lookup``).
    :raises PlabsException: If the user is not a DNS user, still
        administers zones, or the LDAP modification fails.
    """
    user = self.ldap_user_lookup(identifier)
    # Normalise to a list so that `in` is a real membership test, not a
    # substring test (single str) or a TypeError (None).
    object_classes = self._get_ldap_attr(user, 'objectClass')
    if object_classes is None:
        object_classes = []
    elif isinstance(object_classes, str):
        object_classes = [object_classes]
    if 'plabsDnsUser' not in object_classes:
        raise PlabsException('User {} ({}) does not have DNS attributes'.format(self._get_ldap_attr(user, 'uid'), user[0]))

    filter_str = ldap.filter.filter_format('(zoneAdmin=%s)', (user[0],))
    self._log_ldap_search(self.base_zone_dn, ldap.SCOPE_SUBORDINATE, filter_str, ('dn',))
    zone_admin_of = self.ldap_connection.search_s(self.base_zone_dn, ldap.SCOPE_SUBORDINATE, filter_str, ('dn',))
    zone_admin_of = [self._dn_to_domain(x[0]) for x in zone_admin_of]
    if zone_admin_of:
        raise PlabsException('User is still zone admin for following zones: ' + ', '.join(zone_admin_of))

    # The stored key may be absent (None), single (str) or multi-valued
    # (list); calling .encode() on it directly only works for the str case.
    tsig_keys = self._get_ldap_attr(user, 'tsigKey')
    if tsig_keys is None:
        tsig_keys = []
    elif isinstance(tsig_keys, str):
        tsig_keys = [tsig_keys]
    mod_list = [(ldap.MOD_DELETE, 'objectClass', b'plabsDnsUser')]
    mod_list.extend((ldap.MOD_DELETE, 'tsigKey', key.encode('utf-8')) for key in tsig_keys)
    self._log_ldap_modify(user[0], mod_list)
    try:
        self.ldap_connection.modify_s(user[0], mod_list)
    except ldap.LDAPError as err:
        # Keep the LDAP error attached as the cause for easier debugging.
        raise PlabsException('Modify of dn {} failed'.format(user[0])) from err
def ldap_zone_lookup(self, identifier: str, attrlist=()):
    """Fetch a single zone entry by DN or by domain name.

    :param identifier: Either a DN ending in ``base_dn`` or a domain name
        matching ``domain_regex``.
    :param attrlist: Tuple of additional attribute names to fetch besides
        ``dn``, ``primaryMaster``, ``tsigKey`` and ``zoneAdmin``.
    :return: The zone's ``(dn, attributes)`` result entry.
    :raises InvalidIdentifier: If *identifier* matches neither form.
    :raises ZoneNotFound: If the directory reports no such object.
    :raises PlabsException: If the search does not return exactly one entry.
    """
    wanted = tuple(set(('dn', 'primaryMaster', 'tsigKey', 'zoneAdmin') + attrlist))
    if self.dn_regex.search(identifier) and identifier.endswith(self.base_dn):
        zone_dn = identifier
    elif self.domain_regex.search(identifier):
        zone_dn = self._domain_to_dn(identifier)
    else:
        raise InvalidIdentifier(identifier)

    self._log_ldap_search(zone_dn, ldap.SCOPE_BASE, None, wanted)
    try:
        matches = self.ldap_connection.search_s(zone_dn, ldap.SCOPE_BASE, None, wanted)
    except ldap.NO_SUCH_OBJECT:
        raise ZoneNotFound(identifier)

    # A base-scope search is expected to yield exactly one entry; this is
    # purely a safety net.
    if len(matches) != 1:
        raise PlabsException('Zone identifier not unique. {} DNs found'.format(len(matches)))
    return matches[0]
def ldap_zone_add_ou(self, top_levels: str):
    """Create the container entry for a parent domain below the zone base.

    The entry gets the object classes ``dcObject`` and
    ``organizationalUnit``, with ``dc`` set to the first label and ``ou`` set
    to ``domain-<first label>``.

    :param top_levels: Parent domain name, e.g. ``com`` or ``co.uk``.
    :raises InvalidDomainName: If *top_levels* fails ``domain_regex``.
    """
    if self.domain_regex.search(top_levels) is None:
        raise InvalidDomainName
    labels = top_levels.split('.')
    entry = {
        'objectClass': [b'dcObject', b'organizationalUnit'],
        'dc': labels[0].encode('utf-8'),
        'ou': ('domain-' + labels[0]).encode('utf-8')
    }
    mod_list = ldap.modlist.addModlist(entry)
    relative_dn = ','.join('dc=' + label for label in labels)
    dn = self._build_dn(relative_dn, self.base_zone_dn)
    self._log_ldap_add(dn, mod_list)
    self.ldap_connection.add_s(dn, mod_list)
def ldap_zone_add(self, domain: str, zone_admin_identifier: str, tsig_key: str, master_identifier: str):
    """Create the LDAP entry for a zone.

    :param domain: Domain name of the zone; must match ``domain_regex``.
    :param zone_admin_identifier: Identifier resolved by ``ldap_user_lookup``.
    :param tsig_key: TSIG key stored on the zone entry.
    :param master_identifier: Identifier resolved by ``ldap_server_lookup``.
    :return: Dict with ``dn``, ``domain``, ``tsigKey``, ``zoneAdmin`` and
        ``primaryMaster`` of the new entry.
    :raises InvalidDomainName: If *domain* fails ``domain_regex``.
    """
    if self.domain_regex.search(domain) is None:
        raise InvalidDomainName
    labels = domain.split('.')
    relative_dn = ','.join('dc=' + label for label in labels)
    # Resolve the primary master first, then the admin.
    master_dn = self.ldap_server_lookup(master_identifier, ('dn',))[0]
    admin_dn = self.ldap_user_lookup(zone_admin_identifier, ('dn',))[0]
    entry = {
        'objectClass': [b'dcObject', b'plabsDnsZone'],
        'dc': labels[0].encode('utf-8'),
        'tsigKey': tsig_key.encode('utf-8'),
        'zoneAdmin': admin_dn.encode('utf-8'),
        'primaryMaster': master_dn.encode('utf-8')
    }
    mod_list = ldap.modlist.addModlist(entry)
    zone_dn = self._build_dn(relative_dn, self.base_zone_dn)
    self._log_ldap_add(zone_dn, mod_list)
    self.ldap_connection.add_s(zone_dn, mod_list)
    return {
        'dn': zone_dn,
        'domain': domain,
        'tsigKey': tsig_key,
        'zoneAdmin': admin_dn,
        'primaryMaster': master_dn
    }
def add_zone(self, domain: str, zone_admin: str, primary_master: str, tsig_key: str = None):
    """Add a zone, creating any missing parent-domain containers first.

    :param domain: Domain name of the new zone.
    :param zone_admin: Identifier of the administering user.
    :param primary_master: Identifier of the primary master server.
    :param tsig_key: TSIG key for the zone; validated when given, generated
        from *domain* when omitted.
    :return: The dict returned by ``ldap_zone_add`` on success, ``False`` if
        adding the zone entry fails.
    :raises TSIGKey.InvalidKey: If *tsig_key* is given but fails verification.
    :raises InvalidDomainName: If *domain* fails ``domain_regex``.
    """
    if tsig_key and not self.verify_tsig_key(KeyType.Zone, tsig_key):
        raise TSIGKey.InvalidKey
    if not self.domain_regex.search(domain):
        raise InvalidDomainName

    # Every ancestor of the zone, from the top-level label downwards, e.g.
    # 'www.example.co.uk' -> 'uk', 'co.uk', 'example.co.uk'. Slicing the
    # original label order keeps deeper ancestors intact; the former
    # hand-rolled concatenation produced 'example.uk.co' for the third level.
    parent_labels = domain.split('.')[1:]
    ancestors_exist = True
    for depth in range(1, len(parent_labels) + 1):
        ancestor = '.'.join(parent_labels[-depth:])
        if ancestors_exist:
            try:
                self.ldap_zone_lookup(ancestor)
            except ZoneNotFound:
                # Once one level is missing, all deeper levels are created
                # without looking them up.
                ancestors_exist = False
        if not ancestors_exist:
            self.ldap_zone_add_ou(ancestor)

    if not tsig_key:
        tsig_key = self.generate_tsig_key(KeyType.Zone, domain)
    try:
        return self.ldap_zone_add(domain, zone_admin, tsig_key, primary_master)
    except (PlabsException, ldap.LDAPError) as err:
        # The False return is kept for callers, but the reason is now logged.
        self.logger.error('Adding zone \'{}\' failed: {}'.format(domain, err))
        return False
def update_zone_attrib(self, identifier: str, kv: dict):
    """Replace the admin and/or primary master of a zone.

    :param identifier: Zone DN or domain name (see ``ldap_zone_lookup``).
    :param kv: Mapping whose keys may only be ``zoneAdmin`` (a user
        identifier) and ``primaryMaster`` (a server identifier).
    :raises PlabsException: If *kv* contains any other key.
    """
    allowed = ('zoneAdmin', 'primaryMaster')
    for name in kv:
        if name not in allowed:
            raise PlabsException('Invalid attribute name \'{}\''.format(name))

    zone_dn = self.ldap_zone_lookup(identifier)[0]
    changes = []
    if 'zoneAdmin' in kv:
        admin_dn = self.ldap_user_lookup(kv['zoneAdmin'])[0]
        changes.append((ldap.MOD_REPLACE, 'zoneAdmin', admin_dn.encode('utf-8')))
    if 'primaryMaster' in kv:
        master_dn = self.ldap_server_lookup(kv['primaryMaster'])[0]
        changes.append((ldap.MOD_REPLACE, 'primaryMaster', master_dn.encode('utf-8')))
    self._log_ldap_modify(zone_dn, changes)
    self.ldap_connection.modify_s(zone_dn, changes)
def get_zones(self, primary_master_identifier: str = None):
    """List the domain names of all zones, optionally for one server only.

    :param primary_master_identifier: Server name matching ``domain_regex``
        or a ``cn=...`` DN below ``base_server_dn``; when omitted, zones of
        all servers are returned.
    :return: List of domain names of entries that carry a ``primaryMaster``
        attribute.
    :raises InvalidIdentifier: If the identifier matches neither form.
    """
    filterstr = None
    if primary_master_identifier:
        if self.domain_regex.search(primary_master_identifier):
            master = 'cn={},{}'.format(primary_master_identifier, self.base_server_dn)
        elif self.dn_regex.search(primary_master_identifier) \
                and primary_master_identifier.startswith('cn=') \
                and primary_master_identifier.endswith(self.base_server_dn):
            master = primary_master_identifier
        else:
            raise InvalidIdentifier(primary_master_identifier)
        filterstr = ldap.filter.filter_format('(primaryMaster=%s)', [master])

    attribs = ('primaryMaster',)
    # Log the base that is actually searched (this used to name base_user_dn).
    self._log_ldap_search(self.base_zone_dn, ldap.SCOPE_SUBORDINATE, filterstr, attribs)
    ldap_result = self.ldap_connection.search_s(self.base_zone_dn, ldap.SCOPE_SUBORDINATE, filterstr, attribs)
    # Entries without primaryMaster are skipped.
    return [
        self._dn_to_domain(zone[0])
        for zone in ldap_result
        if self._get_ldap_attr(zone, 'primaryMaster')
    ]
def delete_zone(self, identifier: str):
    """Delete the LDAP entry of a zone.

    :param identifier: Zone DN or domain name (see ``ldap_zone_lookup``).
    """
    zone_dn = self.ldap_zone_lookup(identifier)[0]
    self._log_ldap_delete(zone_dn)
    self.ldap_connection.delete_s(zone_dn)
def add_server(self, cn: str, server_address: str, trusted: bool = False, tsig_key: str = None):
    """Create the LDAP entry for a DNS server below ``base_server_dn``.

    :param cn: Server name; must match ``domain_regex``.
    :param server_address: Stored as ``serverAddress``.
    :param trusted: Stored as ``TRUE`` or ``FALSE`` in ``trusted``.
    :param tsig_key: TSIG key for the server; validated when given,
        generated from *cn* when omitted.
    :return: Dict with ``dn``, ``cn`` and ``tsigKey`` if ``add_s`` returns a
        truthy result, otherwise ``False``.
    :raises InvalidDomainName: If *cn* fails ``domain_regex``.
    :raises InvalidTsigKey: If *tsig_key* is given but fails verification.
    """
    if self.domain_regex.search(cn) is None:
        raise InvalidDomainName
    if tsig_key:
        if not self.verify_tsig_key(KeyType.Server, tsig_key):
            raise InvalidTsigKey
    else:
        tsig_key = self.generate_tsig_key(KeyType.Server, cn)

    # NOTE(review): server_address is stored unchecked here, whereas
    # update_server_attrib validates it as an IP address - confirm intent.
    entry = {
        'objectClass': b'plabsDnsServer',
        'cn': cn.encode('utf-8'),
        'tsigKey': tsig_key.encode('utf-8'),
        'serverAddress': server_address.encode('utf-8'),
        'trusted': b'TRUE' if trusted else b'FALSE'
    }
    mod_list = ldap.modlist.addModlist(entry)
    server_dn = self._build_dn('cn=' + cn, self.base_server_dn)
    self._log_ldap_add(server_dn, mod_list)
    if not self.ldap_connection.add_s(server_dn, mod_list):
        return False
    return {
        'dn': server_dn,
        'cn': cn,
        'tsigKey': tsig_key
    }
def update_server_attrib(self, identifier: str, kv: dict):
    """Replace the ``trusted`` flag and/or address of a server.

    :param identifier: Server DN or name (see ``ldap_server_lookup``).
    :param kv: Mapping whose keys may only be ``trusted`` (``True`` or the
        string ``'true'`` in any case means trusted, everything else
        untrusted) and ``serverAddress`` (an IPv4 or IPv6 address).
    :raises PlabsException: If *kv* contains any other key or the address is
        not a valid IP address.
    """
    for k in kv:
        if k not in ('trusted', 'serverAddress'):
            raise PlabsException('Invalid attribute name \'{}\''.format(k))
    server = self.ldap_server_lookup(identifier)
    mod_list = []
    if 'trusted' in kv:
        value = kv['trusted']
        trusted = value is True or (isinstance(value, str) and value.lower() == 'true')
        mod_list.append((ldap.MOD_REPLACE, 'trusted', b'TRUE' if trusted else b'FALSE'))
    if 'serverAddress' in kv:
        address = kv['serverAddress']
        # fullmatch() is required: the IPv6 pattern is not anchored, so
        # search() accepted any text that merely contained an address.
        if not (self.ipv4_regex.fullmatch(address) or self.ipv6_regex.fullmatch(address)):
            raise PlabsException('Not a valid IP address')
        mod_list.append((ldap.MOD_REPLACE, 'serverAddress', address.encode('utf-8')))
    self._log_ldap_modify(server[0], mod_list)
    self.ldap_connection.modify_s(server[0], mod_list)
def get_servers(self, append_domains: bool = False):
    """List the names (``cn``) of all server entries.

    :param append_domains: When true, return a dict mapping each server name
        to the result of ``get_zones`` for it instead of a plain list.
    :return: List of server names, or dict of name -> zone list.
    """
    attrlist = ('cn',)
    self._log_ldap_search(self.base_server_dn, ldap.SCOPE_SUBORDINATE, None, attrlist)
    entries = self.ldap_connection.search_s(self.base_server_dn, ldap.SCOPE_SUBORDINATE, None, attrlist)

    if not append_domains:
        return [self._get_ldap_attr(entry, 'cn') for entry in entries]

    zones_by_server = {}
    for entry in entries:
        name = self._get_ldap_attr(entry, 'cn')
        zones_by_server[name] = self.get_zones(name)
    return zones_by_server
def delete_server(self, identifier: str):
    """Delete a server entry unless zones still name it as primary master.

    :param identifier: Server DN or name (see ``ldap_server_lookup``).
    :raises PlabsException: If ``get_zones`` still returns zones for it.
    """
    server_dn = self.ldap_server_lookup(identifier)[0]
    remaining = self.get_zones(server_dn)
    if remaining:
        raise PlabsException('There are still following domains managed by this server: ' + ', '.join(remaining))
    self._log_ldap_delete(server_dn)
    self.ldap_connection.delete_s(server_dn)
def update_tsig_key(self, key_type: KeyType, identifier: str, tsig_key: str = None):
    """Replace the TSIG key of a user, zone or server.

    :param key_type: Which kind of entry *identifier* refers to.
    :param identifier: Identifier passed to the matching lookup method.
    :param tsig_key: New key; validated when given, otherwise generated from
        the entry's uid (user), domain name (zone) or cn (server).
    :return: Dict with ``dn``, ``identifier`` and the stored ``tsigKey``.
    :raises TSIGKey.InvalidKey: If *tsig_key* is given but fails verification.
    """
    if tsig_key and not self.verify_tsig_key(key_type, tsig_key):
        raise TSIGKey.InvalidKey

    target_dn = None
    key_name = None
    if key_type == KeyType.User:
        entry = self.ldap_user_dns_lookup(identifier)
        target_dn, key_name = entry[0], self._get_ldap_attr(entry, 'uid')
    elif key_type == KeyType.Zone:
        entry = self.ldap_zone_lookup(identifier)
        target_dn, key_name = entry[0], self._dn_to_domain(entry[0])
    elif key_type == KeyType.Server:
        entry = self.ldap_server_lookup(identifier)
        target_dn, key_name = entry[0], self._get_ldap_attr(entry, 'cn')

    new_key = tsig_key if tsig_key else self.generate_tsig_key(key_type, key_name)
    changes = [(ldap.MOD_REPLACE, 'tsigKey', new_key.encode('utf-8'))]
    self._log_ldap_modify(target_dn, changes)
    self.ldap_connection.modify_s(target_dn, changes)
    return {
        'dn': target_dn,
        'identifier': identifier,
        'tsigKey': new_key
    }
def generate_tsig_key(self, key_type: KeyType, name: str, algorithm=TSIGKey.HMAC_SHA256):
    """Generate a TSIG key named ``<name>.<key type value>``.

    :param key_type: Kind of entry the key is for.
    :param name: Identifier of the entry (uid, domain name or server cn).
    :param algorithm: Algorithm passed on to ``TSIGKey.generate_tsig_key``.
    :return: Whatever ``TSIGKey.generate_tsig_key`` returns.
    :raises TSIGKey.UnknownAlgorithm: If *key_type* is not a ``KeyType``.
    """
    # isinstance() instead of `key_type not in KeyType`: the containment test
    # raises TypeError for non-members on Python 3.8-3.11 and matches raw
    # values such as 'zone' on 3.12+, so the guard never worked as intended.
    if not isinstance(key_type, KeyType):
        raise TSIGKey.UnknownAlgorithm
    return TSIGKey.generate_tsig_key('{}.{}'.format(name, key_type.value), algorithm)
def verify_tsig_key(self, key_type: KeyType, key: str):
    """Check that *key* is a well-formed TSIG key for *key_type*.

    The key must have the form ``algorithm:name:secret`` where *name* is
    ``<identifier>.<key type value>`` and *secret* is base64.

    :param key_type: Kind of entry the key is meant for.
    :param key: The TSIG key string.
    :return: ``True`` if the decoded secret has the bit length that
        ``TSIGKey.get_key_length`` reports for the algorithm, else ``False``.
    :raises TSIGKey.InvalidKey: If the key does not have three ``:``-separated
        parts or the secret is not valid base64.
    :raises InvalidIdentifier: If the name does not end in the key type, or
        a user identifier fails ``user_regex``.
    :raises InvalidDomainName: If a server/zone identifier fails
        ``domain_regex``.
    """
    parts = key.split(':')
    if len(parts) != 3:
        raise TSIGKey.InvalidKey
    algorithm, name, secret = parts

    # Require the '.' separator as well: checking only the trailing key-type
    # text accepted names such as 'fooserver' and then cut off the wrong
    # number of characters. The leftover debug print() was removed.
    suffix = '.' + key_type.value
    if not name.endswith(suffix):
        raise InvalidIdentifier('Requestd key type does not match provided key')
    identifier = name[:-len(suffix)]

    if key_type == KeyType.Server or key_type == KeyType.Zone:
        if not self.domain_regex.search(identifier):
            raise InvalidDomainName
    elif key_type == KeyType.User:
        if not self.user_regex.search(identifier):
            raise InvalidIdentifier

    try:
        secret_length = len(base64.b64decode(secret.encode('utf-8'))) * 8
    except ValueError:
        # binascii.Error (a ValueError subclass) signals malformed base64;
        # report it like any other malformed key.
        raise TSIGKey.InvalidKey
    return secret_length == TSIGKey.get_key_length(algorithm)