pyceo/ceod/model/LDAPService.py

393 lines
15 KiB
Python

import contextlib
from typing import Union, Dict, List
from flask import g
import gssapi
import ldap3
from zope import component
from zope.interface import implementer
from ceo_common.errors import UserNotFoundError, GroupNotFoundError, \
UserAlreadyExistsError, GroupAlreadyExistsError
from ceo_common.interfaces import ILDAPService, IConfig, \
IUser, IGroup, IUWLDAPService, IKerberosService
from ceo_common.model import Term
import ceo_common.utils as ceo_common_utils
from .User import User
from .Group import Group
@implementer(ILDAPService)
class LDAPService:
def __init__(self):
    """Loads LDAP settings and UID/GID ranges from the IConfig utility.

    Also looks up the registered IKerberosService, which is used later to
    obtain admin credentials for privileged binds.
    """
    config = component.getUtility(IConfig)
    # server and directory layout
    self.ldap_server_url = config.get('ldap_server_url')
    self.ldap_sasl_realm = config.get('ldap_sasl_realm')
    self.ldap_users_base = config.get('ldap_users_base')
    self.ldap_groups_base = config.get('ldap_groups_base')
    self.ldap_sudo_base = config.get('ldap_sudo_base')
    # ID ranges handed to _get_next_uid (note: config keys are plural)
    self.member_min_id = config.get('members_min_id')
    self.member_max_id = config.get('members_max_id')
    self.club_min_id = config.get('clubs_min_id')
    self.club_max_id = config.get('clubs_max_id')
    self.krb_srv = component.getUtility(IKerberosService)
def _get_ldap_conn(self) -> ldap3.Connection:
    """Returns the LDAP connection for the current request, creating it once.

    The connection is cached on flask's ``g`` so every later call in the
    same request reuses it. If ``g.need_admin_creds`` is set, the admin
    Kerberos token is used; otherwise ``g.client_token`` is used if present.
    When a token is available the bind uses SASL/Kerberos (GSSAPI); when it
    is not, no authentication arguments are passed to ldap3.
    """
    if 'ldap_conn' in g:
        return g.ldap_conn
    if g.get('need_admin_creds', False):
        creds_token = self.krb_srv.get_admin_creds_token()
    elif 'client_token' in g:
        creds_token = g.client_token
    else:
        creds_token = None
    conn_kwargs = {'auto_bind': True, 'raise_exceptions': True}
    if creds_token is not None:
        # Use GSSAPI authentication since creds are available
        conn_kwargs.update(
            authentication=ldap3.SASL,
            sasl_mechanism=ldap3.KERBEROS,
            # see https://github.com/cannatag/ldap3/blob/master/ldap3/protocol/sasl/kerberos.py
            sasl_credentials=(None, None, gssapi.Credentials(token=creds_token)),
        )
    # cache the connection for a single request
    g.ldap_conn = conn = ldap3.Connection(self.ldap_server_url, **conn_kwargs)
    return conn
def _get_readable_entry_for_user(self, conn: ldap3.Connection, username: str) -> ldap3.Entry:
    """Fetches the user's LDAP entry (all attributes) with a BASE-scope search.

    :param conn: connection to search on; its ``entries`` are overwritten.
    :param username: uid of the user; converted to a DN with uid_to_dn.
    :returns: the single entry found at the user's DN.
    :raises UserNotFoundError: if the DN does not exist, or if the search
        completes without returning any entry.
    """
    base = self.uid_to_dn(username)
    try:
        conn.search(
            base, '(objectClass=*)', search_scope=ldap3.BASE,
            attributes=ldap3.ALL_ATTRIBUTES)
    except ldap3.core.exceptions.LDAPNoSuchObjectResult:
        raise UserNotFoundError(username)
    # A search that finishes without raising can still yield no entries;
    # report that as "not found" rather than failing with an IndexError.
    if not conn.entries:
        raise UserNotFoundError(username)
    return conn.entries[0]
def _get_readable_entry_for_group(self, conn: ldap3.Connection, cn: str) -> ldap3.Entry:
    """Fetches the group's LDAP entry (all attributes) with a BASE-scope search.

    :param conn: connection to search on; its ``entries`` are overwritten.
    :param cn: cn of the group; converted to a DN with group_cn_to_dn.
    :returns: the single entry found at the group's DN.
    :raises GroupNotFoundError: if the DN does not exist, or if the search
        completes without returning any entry.
    """
    base = self.group_cn_to_dn(cn)
    try:
        conn.search(
            base, '(objectClass=*)', search_scope=ldap3.BASE,
            attributes=ldap3.ALL_ATTRIBUTES)
    except ldap3.core.exceptions.LDAPNoSuchObjectResult:
        raise GroupNotFoundError(cn)
    # A search that finishes without raising can still yield no entries;
    # report that as "not found" rather than failing with an IndexError.
    if not conn.entries:
        raise GroupNotFoundError(cn)
    return conn.entries[0]
def _get_writable_entry_for_user(self, user: IUser) -> ldap3.WritableEntry:
    """Returns a writable view of the user's LDAP entry.

    The read-only entry is cached on ``user.ldap3_entry``; it is only
    fetched from LDAP when the user object does not already carry one.
    """
    cached = user.ldap3_entry
    if cached is not None:
        return cached.entry_writable()
    user.ldap3_entry = self._get_readable_entry_for_user(
        self._get_ldap_conn(), user.uid)
    return user.ldap3_entry.entry_writable()
def _get_writable_entry_for_group(self, group: IGroup) -> ldap3.WritableEntry:
    """Returns a writable view of the group's LDAP entry.

    The read-only entry is cached on ``group.ldap3_entry``; it is only
    fetched from LDAP when the group object does not already carry one.
    """
    cached = group.ldap3_entry
    if cached is not None:
        return cached.entry_writable()
    group.ldap3_entry = self._get_readable_entry_for_group(
        self._get_ldap_conn(), group.cn)
    return group.ldap3_entry.entry_writable()
def get_user(self, username: str) -> IUser:
    """Looks up a user by uid and returns it as a User object.

    :raises UserNotFoundError: if no entry exists for ``username``.
    """
    return User.deserialize_from_ldap(
        self._get_readable_entry_for_user(self._get_ldap_conn(), username))
def get_group(self, cn: str) -> IGroup:
    """Looks up a group by cn and returns it as a Group object.

    :raises GroupNotFoundError: if no entry exists for ``cn``.
    """
    return Group.deserialize_from_ldap(
        self._get_readable_entry_for_group(self._get_ldap_conn(), cn))
def get_groups_for_user(self, username: str) -> List[str]:
    """Returns, sorted, the cn of every group listing the user as a uniqueMember."""
    conn = self._get_ldap_conn()
    member_filter = '(uniqueMember={})'.format(self.uid_to_dn(username))
    conn.search(self.ldap_groups_base, member_filter, attributes=['cn'])
    group_names = [group.cn.value for group in conn.entries]
    group_names.sort()
    return group_names
def get_display_info_for_users(self, usernames: List[str]) -> List[Dict[str, str]]:
    """Returns the uid, cn and program of each given user, sorted by uid.

    An empty ``usernames`` returns [] without contacting LDAP. A user whose
    program value is empty/missing is reported with program 'Unknown'.
    Usernames that match no entry are simply absent from the result.
    """
    if not usernames:
        return []
    conn = self._get_ldap_conn()
    uid_filter = '(|' + ''.join(f'(uid={uid})' for uid in usernames) + ')'
    conn.search(self.ldap_users_base, uid_filter, attributes=['uid', 'cn', 'program'])
    results = []
    for entry in conn.entries:
        results.append({
            'uid': entry.uid.value,
            'cn': entry.cn.value,
            'program': entry.program.value or 'Unknown',
        })
    results.sort(key=lambda info: info['uid'])
    return results
def get_users_with_positions(self) -> List[IUser]:
    """Returns every user entry that has at least one ``position`` value."""
    conn = self._get_ldap_conn()
    conn.search(self.ldap_users_base, '(position=*)', attributes=ldap3.ALL_ATTRIBUTES)
    return list(map(User.deserialize_from_ldap, conn.entries))
def uid_to_dn(self, uid: str):
    """Builds the DN of the user entry named ``uid`` (the uid is not escaped)."""
    return 'uid={},{}'.format(uid, self.ldap_users_base)
def group_cn_to_dn(self, cn: str):
    """Builds the DN of the group entry named ``cn`` (the cn is not escaped)."""
    return 'cn={},{}'.format(cn, self.ldap_groups_base)
def _get_next_uid(self, conn: 'ldap3.Connection', min_id: int, max_id: int) -> int:
    """Gets the next available UID number between min_id and max_id, inclusive.

    A number counts as taken if any entry under the users base has it as its
    uidNumber or gidNumber. The lookup is a binary search, so it only finds
    the lowest free number when the taken numbers in the range form one
    contiguous run starting at min_id.

    :param conn: LDAP connection used for the existence checks.
    :param min_id: lowest number that may be handed out.
    :param max_id: highest number that may be handed out.
    :returns: the first free number found in [min_id, max_id].
    :raises Exception: 'no UIDs remaining' if the range is full, or if it is
        empty (min_id > max_id).
    """
    def ldap_uid_or_gid_exists(uid: int) -> bool:
        # One match is enough to know the number is in use.
        return conn.search(
            self.ldap_users_base,
            f'(|(uidNumber={uid})(gidNumber={uid}))',
            size_limit=1)

    # With an empty range, the bounds could never meet and the search loop
    # would spin forever; treat it as having no IDs left.
    if min_id > max_id:
        raise Exception('no UIDs remaining')
    lo, hi = min_id, max_id
    while lo < hi:
        mid = (lo + hi) // 2
        if ldap_uid_or_gid_exists(mid):
            lo = mid + 1
        else:
            hi = mid
    if ldap_uid_or_gid_exists(lo):
        raise Exception('no UIDs remaining')
    return lo
def add_sudo_role(self, uid: str):
    """Creates a sudoRole entry named '%<uid>' under the sudo base.

    The role's sudoUser is '%<uid>' and its sudoRunAsUser is ``uid``. It
    allows ALL commands on ALL hosts with the '!authenticate' option.
    """
    role_name = '%' + uid
    conn = self._get_ldap_conn()
    writer = ldap3.Writer(conn, ldap3.ObjectDef(['sudoRole'], conn))
    entry = writer.new(f'cn={role_name},{self.ldap_sudo_base}')
    entry.cn = role_name
    entry.sudoUser = role_name
    entry.sudoHost = 'ALL'
    entry.sudoCommand = 'ALL'
    entry.sudoOption = ['!authenticate']
    entry.sudoRunAsUser = uid
    writer.commit()
def remove_sudo_role(self, uid: str):
    """Deletes the sudoRole entry 'cn=%<uid>' under the sudo base."""
    self._get_ldap_conn().delete('cn=%{},{}'.format(uid, self.ldap_sudo_base))
def add_user(self, user: IUser):
    """Creates the LDAP entry for a new member or club account.

    A UID number is allocated from the club or member range and written
    back to both ``user.uid_number`` and ``user.gid_number`` (same value)
    before the entry is committed.

    :raises AssertionError: (when assertions are enabled) if a new member
        lacks a given name or surname, or has neither terms nor
        non-member terms.
    :raises UserAlreadyExistsError: if an entry with this uid already exists.
    """
    classes = ['top', 'account', 'posixAccount', 'shadowAccount']
    if user.is_club():
        low, high = self.club_min_id, self.club_max_id
        classes.append('club')
    else:
        assert user.given_name and user.sn, \
            'First name and last name must be specified for new members'
        assert user.terms or user.non_member_terms, \
            'terms and non_member_terms cannot both be empty'
        low, high = self.member_min_id, self.member_max_id
        classes.append('member')
    if user.mail_local_addresses:
        classes.append('inetLocalMailRecipient')

    conn = self._get_ldap_conn()
    obj_def = ldap3.ObjectDef(classes, conn)
    # The GID number is always the same as the newly allocated UID number.
    user.uid_number = user.gid_number = self._get_next_uid(conn, low, high)

    writer = ldap3.Writer(conn, obj_def)
    entry = writer.new(self.uid_to_dn(user.uid))
    entry.cn = user.cn
    entry.uidNumber = user.uid_number
    entry.gidNumber = user.gid_number
    entry.homeDirectory = user.home_directory
    # Optional attributes are written (in this order) only when truthy.
    optional_attrs = (
        ('loginShell', user.login_shell),
        ('program', user.program),
        ('term', user.terms),
        ('nonMemberTerm', user.non_member_terms),
        ('position', user.positions),
        ('mailLocalAddress', user.mail_local_addresses),
    )
    for ldap_attr, value in optional_attrs:
        if value:
            setattr(entry, ldap_attr, value)
    if user.is_club_rep:
        entry.isClubRep = True
    if not user.is_club():
        entry.givenName = user.given_name
        entry.sn = user.sn
    entry.userPassword = '{SASL}%s@%s' % (user.uid, self.ldap_sasl_realm)

    try:
        writer.commit()
    except ldap3.core.exceptions.LDAPEntryAlreadyExistsResult:
        raise UserAlreadyExistsError()
@contextlib.contextmanager
def entry_ctx_for_user(self, user: IUser):
    """Context manager yielding the user's writable LDAP entry.

    Changes made to the yielded entry are committed when the ``with`` body
    finishes normally. If the body raises, the exception propagates out of
    the ``yield`` and the commit line is never reached.
    """
    writable = self._get_writable_entry_for_user(user)
    yield writable
    writable.entry_commit_changes()
def remove_user(self, user: IUser):
    """Deletes the LDAP entry at the user's DN (built from ``user.uid``)."""
    self._get_ldap_conn().delete(self.uid_to_dn(user.uid))
def add_group(self, group: IGroup) -> IGroup:
    """Creates a new group/posixGroup entry in LDAP.

    The caller must already have set ``group.gid_number``. Members are
    stored as uniqueMember DNs built from their uids; the description is
    written only if set.

    :returns: the ``group`` argument, as declared by the return annotation.
    :raises GroupAlreadyExistsError: if an entry with this cn already exists.
    """
    conn = self._get_ldap_conn()
    # make sure that the caller initialized the GID number
    assert group.gid_number
    obj_def = ldap3.ObjectDef(['group', 'posixGroup'], conn)
    writer = ldap3.Writer(conn, obj_def)
    entry = writer.new(self.group_cn_to_dn(group.cn))
    entry.cn = group.cn
    entry.gidNumber = group.gid_number
    if group.members:
        entry.uniqueMember = [self.uid_to_dn(uid) for uid in group.members]
    if group.description:
        entry.description = group.description
    try:
        writer.commit()
    except ldap3.core.exceptions.LDAPEntryAlreadyExistsResult:
        raise GroupAlreadyExistsError()
    # Honour the declared `-> IGroup` return type instead of falling off the
    # end and returning None.
    return group
def get_nonflagged_expired_users(self) -> List[IUser]:
    """Returns expired members who have not yet been flagged.

    A member (objectClass=member) is returned when it has no term or
    nonMemberTerm for the current term, does not have shadowExpire=1, and
    is not a member of the 'syscom' group. During the month in which the
    current term starts, terms from the previous term also count as current.
    """
    # Fetch syscom first: get_group() runs its own search on the same cached
    # connection, which would replace conn.entries if done afterwards.
    syscom_members = self.get_group('syscom').members
    current = Term.current()
    active_terms = [current]
    if ceo_common_utils.get_current_datetime().month == current.start_month():
        # Include last term too if the new term just started
        active_terms.append(current - 1)
    clauses = []
    for active in active_terms:
        clauses.append(f'term={active}')
        clauses.append(f'nonMemberTerm={active}')
    excluded = '(|(shadowExpire=1)' + ''.join(f'({clause})' for clause in clauses) + ')'
    query = f'(&(!{excluded})(objectClass=member))'

    conn = self._get_ldap_conn()
    conn.search(
        self.ldap_users_base,
        query,
        attributes=ldap3.ALL_ATTRIBUTES,
        search_scope=ldap3.LEVEL)
    expired = []
    for entry in conn.entries:
        if entry.uid.value in syscom_members:
            continue
        expired.append(User.deserialize_from_ldap(entry))
    return expired
def get_expiring_users(self) -> List[IUser]:
    """Returns users who had a term last term but have none this term.

    Only searches during the month in which the current term starts;
    at any other time an empty list is returned.
    """
    current = Term.current()
    now = ceo_common_utils.get_current_datetime()
    if now.month == current.start_month():
        previous = current - 1
        conn = self._get_ldap_conn()
        conn.search(
            self.ldap_users_base,
            f'(&(term={previous})(!(term={current})))',
            attributes=ldap3.ALL_ATTRIBUTES,
            search_scope=ldap3.LEVEL)
        return [User.deserialize_from_ldap(found) for found in conn.entries]
    # We only send membership renewal reminders at the
    # start of a term
    return []
@contextlib.contextmanager
def entry_ctx_for_group(self, group: IGroup):
    """Context manager yielding the group's writable LDAP entry.

    Changes made to the yielded entry are committed when the ``with`` body
    finishes normally. If the body raises, the exception propagates out of
    the ``yield`` and the commit line is never reached.
    """
    writable = self._get_writable_entry_for_group(group)
    yield writable
    writable.entry_commit_changes()
def remove_group(self, group: IGroup):
    """Deletes the LDAP entry at the group's DN (built from ``group.cn``)."""
    self._get_ldap_conn().delete(self.group_cn_to_dn(group.cn))
def update_programs(
    self,
    dry_run: bool = False,
    members: Union[List[str], None] = None,
    uwldap_batch_size: int = 100,
):
    """Syncs each member's ``program`` attribute with the value from UWLDAP.

    :param dry_run: if True, compute the differences but write nothing.
    :param members: uids to restrict the update to; when None or empty,
        every entry with objectClass=member is considered.
    :param uwldap_batch_size: number of uids sent per UWLDAP query.
    :returns: (uid, old_program, new_program) tuples for every user whose
        program differs; unless ``dry_run``, each one has been replaced.
    """
    if members:
        search_filter = '(|' + ''.join(f'(uid={uid})' for uid in members) + ')'
    else:
        search_filter = '(objectClass=member)'
    conn = self._get_ldap_conn()
    conn.search(
        self.ldap_users_base, search_filter, attributes=['uid', 'program'])
    uids = []
    csc_programs = []
    for entry in conn.entries:
        uids.append(entry.uid.value)
        csc_programs.append(entry.program.value)

    uwldap_srv = component.getUtility(IUWLDAPService)
    # send queries in small batches so that we don't have an
    # enormous filter in our query to UWLDAP
    uw_programs = []
    for start in range(0, len(uids), uwldap_batch_size):
        uw_programs.extend(
            uwldap_srv.get_programs_for_users(uids[start:start + uwldap_batch_size]))

    # A value of None means the 'ou' attribute was missing in UWLDAP or no
    # UWLDAP entry was found. Like an 'expired' or 'orphaned' record, it is
    # taken to mean that the member graduated.
    uw_programs = [
        'Alumni' if program in (None, 'expired', 'orphaned') else program
        for program in uw_programs
    ]

    users_to_change = []
    for i, uid in enumerate(uids):
        if csc_programs[i] != uw_programs[i]:
            users_to_change.append((uid, csc_programs[i], uw_programs[i]))
    if not dry_run:
        for uid, _old_program, new_program in users_to_change:
            conn.modify(
                self.uid_to_dn(uid),
                {'program': [(ldap3.MODIFY_REPLACE, [new_program])]})
    return users_to_change
def _get_club_uids(self, conn: ldap3.Connection) -> List[str]:
    """Returns the uid of every entry under the users base with objectClass=club."""
    conn.search(self.ldap_users_base, '(objectClass=club)', attributes=['uid'])
    return [club.uid.value for club in conn.entries]
def get_clubs(self) -> List[IGroup]:
    """Returns, as Group objects, the groups whose cn matches a club uid.

    Club uids are gathered first. The groups are then searched in batches,
    with at most 100 (cn=...) terms in each OR-filter.
    """
    batch_size = 100
    conn = self._get_ldap_conn()
    club_uids = self._get_club_uids(conn)
    clubs = []
    for start in range(0, len(club_uids), batch_size):
        cn_terms = ''.join(f'(cn={uid})' for uid in club_uids[start:start + batch_size])
        conn.search(self.ldap_groups_base, f'(|{cn_terms})', attributes=ldap3.ALL_ATTRIBUTES)
        # Consume the results now; the next search replaces conn.entries.
        clubs.extend(Group.deserialize_from_ldap(entry) for entry in conn.entries)
    return clubs
def get_club_reps_non_member_terms(self, club_reps: List[str]) -> Dict[str, List[Term]]:
    """Maps each given club rep's uid to their nonMemberTerm values as Terms.

    Reps without a nonMemberTerm attribute map to an empty list. Uids that
    match no entry are absent from the result. Uids are looked up in
    batches of at most 100 per search.
    """
    batch_size = 100
    conn = self._get_ldap_conn()
    terms_by_uid = {}
    for start in range(0, len(club_reps), batch_size):
        uid_terms = ''.join(f'(uid={uid})' for uid in club_reps[start:start + batch_size])
        conn.search(self.ldap_users_base, f'(|{uid_terms})', attributes=['uid', 'nonMemberTerm'])
        for entry in conn.entries:
            has_terms = 'nonMemberTerm' in entry.entry_attributes
            terms_by_uid[entry.uid.value] = (
                [Term(raw) for raw in entry.nonMemberTerm.values] if has_terms else [])
    return terms_by_uid