Fix gss error reporting bug
[mspang/pyceo.git] / src / gss.c
1 #include <string.h>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <grp.h>
5
6 #include "util.h"
7 #include "gss.h"
8 #include "net.h"
9 #include "strbuf.h"
10
11 static gss_cred_id_t my_creds = GSS_C_NO_CREDENTIAL;
12 static gss_ctx_id_t context_handle = GSS_C_NO_CONTEXT;
13 static gss_name_t peer_name = GSS_C_NO_NAME;
14 static gss_name_t imported_service = GSS_C_NO_NAME;
15 static char *peer_principal;
16 static char *peer_username;
17 static OM_uint32 ret_flags;
18 static int complete;
19 char service_name[128];
20
21 void free_gss(void) {
22     OM_uint32 maj_stat, min_stat;
23
24     if (peer_name) {
25         maj_stat = gss_release_name(&min_stat, &peer_name);
26         if (maj_stat != GSS_S_COMPLETE)
27             gss_fatal("gss_release_name", maj_stat, min_stat);
28     }
29
30     if (imported_service) {
31         maj_stat = gss_release_name(&min_stat, &imported_service);
32         if (maj_stat != GSS_S_COMPLETE)
33             gss_fatal("gss_release_name", maj_stat, min_stat);
34     }
35
36     if (context_handle) {
37         maj_stat = gss_delete_sec_context(&min_stat, &context_handle, GSS_C_NO_BUFFER);
38         if (maj_stat != GSS_S_COMPLETE)
39             gss_fatal("gss_delete_sec_context", maj_stat, min_stat);
40     }
41
42     if (my_creds) {
43         maj_stat = gss_release_cred(&min_stat, &my_creds);
44         if (maj_stat != GSS_S_COMPLETE)
45             gss_fatal("gss_release_creds", maj_stat, min_stat);
46     }
47
48     free(peer_principal);
49     free(peer_username);
50 }
51
52 static char *gssbuf2str(gss_buffer_t buf) {
53     char *msgstr = xmalloc(buf->length + 1);
54     memcpy(msgstr, buf->value, buf->length);
55     msgstr[buf->length] = '\0';
56     return msgstr;
57 }
58
59 static void display_status(char *prefix, OM_uint32 code, int type) {
60     OM_uint32 maj_stat, min_stat;
61     gss_buffer_desc msg;
62     OM_uint32 msg_ctx = 0;
63     char *msgstr;
64
65     maj_stat = gss_display_status(&min_stat, code, type, GSS_C_NULL_OID,
66                                   &msg_ctx, &msg);
67     msgstr = gssbuf2str(&msg);
68     logmsg(LOG_ERR, "%s: %s", prefix, msgstr);
69     gss_release_buffer(&min_stat, &msg);
70     free(msgstr);
71
72     while (msg_ctx) {
73         maj_stat = gss_display_status(&min_stat, code, type, GSS_C_NULL_OID,
74                                       &msg_ctx, &msg);
75         msgstr = gssbuf2str(&msg);
76         logmsg(LOG_ERR, "additional: %s", msgstr);
77         gss_release_buffer(&min_stat, &msg);
78         free(msgstr);
79     }
80 }
81
82 void gss_fatal(char *msg, OM_uint32 maj_stat, OM_uint32 min_stat) {
83     logmsg(LOG_ERR, "fatal: %s", msg);
84     display_status("major", maj_stat, GSS_C_GSS_CODE);
85     display_status("minor", min_stat, GSS_C_MECH_CODE);
86     exit(1);
87 }
88
89 static void import_service(const char *service, const char *hostname) {
90     OM_uint32 maj_stat, min_stat;
91     gss_buffer_desc buf_desc;
92
93     if (snprintf(service_name, sizeof(service_name),
94                  "%s@%s", service, hostname) >= sizeof(service_name))
95         fatal("service name too long");
96
97     buf_desc.value = service_name;
98     buf_desc.length = strlen(service_name);
99
100     maj_stat = gss_import_name(&min_stat, &buf_desc,
101                                GSS_C_NT_HOSTBASED_SERVICE, &imported_service);
102     if (maj_stat != GSS_S_COMPLETE)
103         gss_fatal("gss_import_name", maj_stat, min_stat);
104 }
105
106 static void check_services(OM_uint32 flags) {
107     debug("gss services: %sconf %sinteg %smutual %sreplay %ssequence",
108             flags & GSS_C_CONF_FLAG     ? "+" : "-",
109             flags & GSS_C_INTEG_FLAG    ? "+" : "-",
110             flags & GSS_C_MUTUAL_FLAG   ? "+" : "-",
111             flags & GSS_C_REPLAY_FLAG   ? "+" : "-",
112             flags & GSS_C_SEQUENCE_FLAG ? "+" : "-");
113     if (~flags & GSS_C_CONF_FLAG)
114         fatal("confidentiality service required");
115     if (~flags & GSS_C_INTEG_FLAG)
116         fatal("integrity service required");
117     if (~flags & GSS_C_MUTUAL_FLAG)
118         fatal("mutual authentication required");
119 }
120
121 void server_acquire_creds(const char *service) {
122     OM_uint32 maj_stat, min_stat;
123     OM_uint32 time_rec;
124
125     if (!strlen(fqdn.buf))
126         fatal("empty fqdn");
127
128     import_service(service, fqdn.buf);
129
130     notice("acquiring credentials for %s", service_name);
131
132     maj_stat = gss_acquire_cred(&min_stat, imported_service, GSS_C_INDEFINITE,
133                                 GSS_C_NULL_OID_SET, GSS_C_ACCEPT, &my_creds,
134                                 NULL, &time_rec);
135     if (maj_stat != GSS_S_COMPLETE)
136         gss_fatal("gss_acquire_cred", maj_stat, min_stat);
137
138     if (time_rec != GSS_C_INDEFINITE)
139         fatal("credentials valid for %d seconds (oops)", time_rec);
140 }
141
/*
 * Prepare the client side by importing the target name "service@hostname"
 * that process_client_token() authenticates against. No credentials are
 * acquired here: gss_init_sec_context() is later called with
 * GSS_C_NO_CREDENTIAL, i.e. the caller's default credentials.
 *
 * A missing or empty hostname is rejected up front, matching the empty-fqdn
 * check on the server side, instead of importing a malformed "service@" name
 * (or passing NULL to a %s conversion, which is undefined).
 */
void client_acquire_creds(const char *service, const char *hostname) {
    if (hostname == NULL || hostname[0] == '\0')
        fatal("empty hostname");

    import_service(service, hostname);
}
145
/*
 * Derive a username from a principal name by dropping everything from the
 * first '@' onward ("user@REALM" -> "user"). A principal without '@' is
 * copied unchanged. Returns a new heap string the caller must free().
 */
static char *princ_to_username(char *princ) {
    char *user = xstrdup(princ);

    /* strcspn() yields the index of the first '@', or the string length if
     * there is none, in which case this just rewrites the existing NUL. */
    user[strcspn(user, "@")] = '\0';
    return user;
}
153
154 int process_server_token(gss_buffer_t incoming_tok, gss_buffer_t outgoing_tok) {
155     OM_uint32 maj_stat, min_stat;
156     OM_uint32 time_rec;
157     gss_OID name_type;
158     gss_buffer_desc peer_princ;
159
160     if (complete)
161         fatal("unexpected %zd-byte token from peer", incoming_tok->length);
162
163     maj_stat = gss_accept_sec_context(&min_stat, &context_handle, my_creds,
164             incoming_tok, GSS_C_NO_CHANNEL_BINDINGS, &peer_name, NULL,
165             outgoing_tok, &ret_flags, &time_rec, NULL);
166     if (maj_stat == GSS_S_COMPLETE) {
167         check_services(ret_flags);
168
169         complete = 1;
170
171         maj_stat = gss_display_name(&min_stat, peer_name, &peer_princ, &name_type);
172         if (maj_stat != GSS_S_COMPLETE)
173             gss_fatal("gss_display_name", maj_stat, min_stat);
174
175         peer_principal = xstrdup((char *)peer_princ.value);
176         peer_username = princ_to_username((char *)peer_princ.value);
177
178         notice("client authenticated as %s", peer_principal);
179         debug("context expires in %d seconds", time_rec);
180
181         maj_stat = gss_release_buffer(&min_stat, &peer_princ);
182         if (maj_stat != GSS_S_COMPLETE)
183             gss_fatal("gss_release_buffer", maj_stat, min_stat);
184
185     } else if (maj_stat != GSS_S_CONTINUE_NEEDED) {
186         gss_fatal("gss_accept_sec_context", maj_stat, min_stat);
187     }
188
189     return complete;
190 }
191
192 int process_client_token(gss_buffer_t incoming_tok, gss_buffer_t outgoing_tok) {
193     OM_uint32 maj_stat, min_stat;
194     OM_uint32 time_rec;
195     gss_OID_desc krb5 = *gss_mech_krb5;
196
197     if (complete)
198         fatal("unexpected token from peer");
199
200     maj_stat = gss_init_sec_context(&min_stat, GSS_C_NO_CREDENTIAL, &context_handle,
201                                     imported_service, &krb5, GSS_C_MUTUAL_FLAG |
202                                     GSS_C_REPLAY_FLAG | GSS_C_SEQUENCE_FLAG,
203                                     GSS_C_INDEFINITE, GSS_C_NO_CHANNEL_BINDINGS,
204                                     incoming_tok, NULL, outgoing_tok, &ret_flags,
205                                     &time_rec);
206     if (maj_stat == GSS_S_COMPLETE) {
207         notice("server authenticated as %s", service_name);
208         notice("context expires in %d seconds", time_rec);
209
210         check_services(ret_flags);
211
212         complete = 1;
213
214     } else if (maj_stat != GSS_S_CONTINUE_NEEDED) {
215         gss_fatal("gss_init_sec_context", maj_stat, min_stat);
216     }
217
218     return complete;
219 }
220
/*
 * Start the client side of the handshake: run process_client_token() with no
 * input token (GSS_C_NO_BUFFER) to produce the first token for the server.
 * Returns process_client_token()'s completion flag.
 */
int initial_client_token(gss_buffer_t outgoing_tok) {
    return process_client_token(GSS_C_NO_BUFFER, outgoing_tok);
}

/*
 * Principal name of the authenticated client (server side). It is NULL until
 * process_server_token() completes. The string is owned by this module and
 * is released by free_gss(), so callers must not free it.
 */
char *client_principal(void) {
    return peer_principal;
}

/*
 * Username of the authenticated client: the principal with everything from
 * the first '@' removed. Same lifetime and ownership rules as
 * client_principal().
 */
char *client_username(void) {
    return peer_username;
}
232