Removed hardcoding of URI.
[public/libpam-csc.git] / pam_csc.c
1 #define PAM_SM_ACCOUNT
2 #include <unistd.h>
3 #include <sys/types.h>
4 #include <sys/time.h>
5 #include <time.h>
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <string.h>
9 #include <security/pam_modules.h>
10 #include <security/pam_appl.h>
11 #include <ldap.h>
12 #include <syslog.h>
13 #include <pwd.h>
14
15 #define PAM_CSC_LDAP_USER_BASE_DN       "ou=People,dc=csclub,dc=uwaterloo,dc=ca"
16 #define PAM_CSC_LDAP_GROUP_BASE_DN      "ou=Group,dc=csclub,dc=uwaterloo,dc=ca"
17 #define PAM_CSC_LDAP_TIMEOUT            5
18 #define PAM_CSC_ALLOWED_GROUPS          "cn=staff"
19 #define PAM_CSC_MINIMUM_UID             1000
20 #define PAM_CSC_EXPIRED_MSG \
21     "*****************************************************************************\n" \
22     "*                                                                           *\n" \
23     "*    Your account has expired - please contact the Computer Science Club    *\n" \
24     "*                                                                           *\n" \
25     "*****************************************************************************\n"
26
27 /*
28  * User terms are defined as (3 * year + term) where term is:
29  *   0 = Winter, 1 = Spring, 2 = Fall
30  * Term is a string in the form [f|w|s][year]
31  */
32
33 enum check_user_type_t
34 {
35     check_user_exists,
36     check_user_cur_term,
37     check_user_prev_term,
38     check_user_groups
39 };
40
41 #define HANDLE_WARN \
42 { \
43     syslog(LOG_AUTHPRIV | LOG_WARNING, "pam_csc generated a warning on line %d of %s\n", __LINE__, __FILE__); \
44     retval = PAM_SUCCESS; \
45     goto cleanup; \
46 }
47
48 #define WARN_ZERO(x) \
49     if( (x) == 0 ) HANDLE_WARN
50
51 #define WARN_NEG1(x) \
52     if( (x) == -1 ) HANDLE_WARN
53
54 #define WARN_PAM(x) \
55     if( (x) != PAM_SUCCESS ) HANDLE_WARN
56
57 #define WARN_LDAP(x) \
58     if( (x) != LDAP_SUCCESS ) HANDLE_WARN
59
60 char* escape_ldap_string(const char* src)
61 {
62     char *dst, *dstPtr;
63     int i;
64
65     if(!(dst = malloc(2 * strlen(src) + 1)))
66         return NULL;
67     dstPtr = dst;
68
69     for(i = 0; i < strlen(src); i++)
70     {
71         if(src[i] == '*' || src[i] == '(' || src[i] == ')' || src[i] == '\\')
72         {
73             dstPtr[0] = '\\';
74             dstPtr++;
75         }
76         dstPtr[0] = src[i];
77         dstPtr++;
78     }
79     dstPtr[0] = '\0';
80
81     return dst;
82 }
83
84 int check_user(const char* username, enum check_user_type_t checkType)
85 {
86     int retval = PAM_SUCCESS;
87     time_t curTime;
88     struct tm* localTime;
89     int longTerm, year, term;
90     LDAP* ld = NULL;
91     static const char termChars[] = {'w', 's', 'f'};
92     char* usernameEscaped = NULL;
93     char* filter = NULL;
94     char* attr[] = {"objectClass", NULL};
95     struct timeval timeout = {PAM_CSC_LDAP_TIMEOUT, 0};
96     LDAPMessage* res = NULL;
97     char* baseDN = NULL;
98
99     /* fail-safe for root */
100     if(strcmp(username, "root") == 0)
101     {
102         return PAM_SUCCESS;
103     }
104
105     /* connect and bind */
106     WARN_LDAP( ldap_create(&ld) )
107     WARN_NEG1( ldap_simple_bind(ld, NULL, NULL) )
108
109     WARN_ZERO( usernameEscaped = escape_ldap_string(username) );
110     switch(checkType)
111     {
112     case check_user_exists:
113
114         /* format filter */
115         WARN_ZERO( filter = malloc(50 + strlen(usernameEscaped)) )
116         sprintf(filter, "(uid=%s)", usernameEscaped);
117         baseDN = PAM_CSC_LDAP_USER_BASE_DN;
118         break;
119
120     case check_user_prev_term:
121     case check_user_cur_term:
122
123         /* get term info and compute current and previous term */
124         WARN_NEG1( curTime = time(NULL) )
125         WARN_ZERO( localTime = localtime(&curTime) )
126         longTerm = 3 * (1900 + localTime->tm_year) + (localTime->tm_mon / 4);
127         if(checkType == check_user_prev_term)
128             longTerm--;
129         term = termChars[longTerm % 3];
130         year = longTerm / 3;
131
132         /* format filter */
133         WARN_ZERO( filter = malloc(100 + strlen(usernameEscaped)) )
134         sprintf(filter, "(&(uid=%s)(|(&(objectClass=member)(term=%c%d))(!(objectClass=member))))", 
135             usernameEscaped, term, year);
136         baseDN = PAM_CSC_LDAP_USER_BASE_DN;
137         break;
138
139     case check_user_groups:
140
141         /* format filter */
142         WARN_ZERO( filter = malloc(50 + strlen(PAM_CSC_ALLOWED_GROUPS) + strlen(usernameEscaped)) )
143         sprintf(filter, "(&(objectClass=posixGroup)(%s)(memberUid=%s))", PAM_CSC_ALLOWED_GROUPS, usernameEscaped);
144         baseDN = PAM_CSC_LDAP_GROUP_BASE_DN;
145         break;
146     }
147
148     /* search */
149     WARN_LDAP( ldap_search_st(ld, baseDN, LDAP_SCOPE_SUBTREE, filter, attr, 1, &timeout, &res) )
150     if((term = ldap_count_entries(ld, res)) == 0)
151         retval = PAM_AUTH_ERR;
152
153 cleanup:
154
155     if(usernameEscaped) free(usernameEscaped);
156     if(res) ldap_msgfree(res);
157     if(filter) free(filter);
158     if(ld) ldap_unbind(ld);
159
160     return retval;
161 }
162
163 int print_pam_message(pam_handle_t* pamh, char* msg, int style)
164 {
165     int retval = PAM_SUCCESS;
166     struct pam_conv* pamConv;
167     struct pam_message pamMessage;
168     struct pam_message* pamMessages[1];
169     struct pam_response* pamResponse;
170
171     /* output message */
172     WARN_PAM( pam_get_item(pamh, PAM_CONV, (const void**)&pamConv) )
173     pamMessages[0] = &pamMessage;
174     pamMessage.msg_style = style;
175     pamMessage.msg = msg;
176     WARN_PAM( pamConv->conv(1, (const struct pam_message**)pamMessages, 
177         &pamResponse, pamConv->appdata_ptr) )
178
179 cleanup:
180
181     return retval;
182 }
183
184 PAM_EXTERN int pam_sm_acct_mgmt(pam_handle_t* pamh, int flags, int argc, const char* argv[])
185 {
186     const char* username;
187     struct passwd* pwd;
188
189     /* determine username */
190     if((pam_get_user(pamh, &username, NULL) != PAM_SUCCESS) || !username)
191     {
192         return PAM_USER_UNKNOWN;
193     }
194
195     /* check uid */
196     pwd = getpwnam(username);
197     if(pwd && pwd->pw_uid < PAM_CSC_MINIMUM_UID)
198     {
199         return PAM_SUCCESS;
200     }
201
202     /* check if user exists in ldap */
203     if(check_user(username, check_user_exists) == PAM_AUTH_ERR)
204     {
205         return PAM_SUCCESS;
206     }
207
208     /* check if user is registered for the current term */
209     if(check_user(username, check_user_cur_term) == PAM_SUCCESS)
210     {
211         return PAM_SUCCESS;
212     }
213
214     /* check if user is registered for the previous term */
215     if(check_user(username, check_user_prev_term) == PAM_SUCCESS)
216     {
217         /* show warning */
218         syslog(LOG_AUTHPRIV | LOG_NOTICE, "(pam_csc): %s was not registered for current term but was registered for previous term - permitting login\n", username);
219         print_pam_message(pamh, PAM_CSC_EXPIRED_MSG, PAM_TEXT_INFO);
220         return PAM_SUCCESS;
221     }
222
223     /* check if user is in allowed groups */
224     if(check_user(username, check_user_groups) == PAM_SUCCESS)
225     {
226         /* show warning */
227         print_pam_message(pamh, PAM_CSC_EXPIRED_MSG, PAM_TEXT_INFO);
228         syslog(LOG_AUTHPRIV | LOG_NOTICE, "(pam_csc): %s was not registered but was in allowed groups - permitting login\n", username);
229         return PAM_SUCCESS;
230     }
231
232     /* account has expired - show prompt */
233     print_pam_message(pamh, PAM_CSC_EXPIRED_MSG, PAM_ERROR_MSG);
234     syslog(LOG_AUTHPRIV | LOG_NOTICE, "(pam_csc): %s was not registered and was not in allowed groups - denying login\n", username);
235
236     return PAM_AUTH_ERR;
237 }