New release (version 0.2).
[public/pyceo-broken.git] / pylib / csc / backends / db.py
1 """
2 Database Backend Interface
3
4 This module is intended to be a thin wrapper around CEO database operations.
5 Methods on the connection class correspond in a straightforward way to SQL
6 queries. These methods may restructure and clean up query output but may make
7 no other assumptions about its content or purpose.
8
9 This module makes use of the PyGreSQL Python bindings to libpq,
10 PostgreSQL's native C client library.
11 """
12 import pgdb
13
14
15 class DBException(Exception):
16     """Exception class for database-related errors."""
17     pass
18     
19     
20 class DBConnection(object):
21     """
22     A connection to CEO's backend database. All database queries
23     and updates are made via this class.
24     
25     Exceptions: (all methods)
26         DBException - on database query failure
27
28     Note: Updates will never take place until commit() is called.
29
30     Note: In the event that any method of this class raises a
31           DBException and that exception is caught, rollback()
32           must be called before further queries will be successful.
33     
34     Example:
35         connection = DBConnection()
36         connection.connect("localhost", "ceo")
37         
38         # make queries and updates, i.e.
39         connection.insert_member("Calum T. Dalek")
40         
41         connection.commit()
42         connection.disconnect()
43     """
44
45     def __init__(self):
46         self.cnx = None
47         self.cursor = None
48
49           
50     def connect(self, hostname=None, database=None, username=None, password=None):
51         """
52         Establishes the connection to CEO's PostgreSQL database.
53         
54         Parameters:
55             hostname - hostname:port to connect to
56             database - name of database
57             username - user to authenticate as
58             password - password of username
59         """
60
61         if self.cnx: raise DBException("unable to connect: already connected")
62         
63         try:
64             self.cnx = pgdb.connect(host=hostname, database=database,
65                     user=username, password=password)
66             self.cursor = self.cnx.cursor()
67         except pgdb.Error, e:
68             raise DBException("unable to connect: %s" % e)
69
70
71     def disconnect(self):
72         """Closes the connection to CEO's PostgreSQL database."""
73
74         if self.cursor:
75             self.cursor.close()
76             self.cursor = None
77
78         if self.cnx:
79             self.cnx.close()
80             self.cnx = None
81
82     
83     def connected(self):
84         """Determine whether the connection has been established."""
85
86         return self.cnx is not None
87
88
89     def commit(self):
90         """Commits the current transaction and starts a new one."""
91
92         self.cnx.commit()
93
94
95     def rollback(self):
96         """Aborts the current transaction."""
97
98         self.cnx.rollback()
99
100
101
102     ### Member-related methods ###
103     
104     def select_members(self, sql, params=None):
105         """
106         Retrieves a list CSC members selected by given SQL statement.
107         
108         This is a helper function that should generally not be called directly.
109         
110         Parameters:
111             sql    - the SELECT sql statement
112             params - parameters for the SQL statement
113
114         The sql statement must select the six columns
115         (memberid, name, studentid, program, type, userid)
116         from the members table in that order.
117         
118         Returns: a memberid-keyed dictionary whose values are
119                  column-keyed dictionaries with member attributes
120         """
121         
122         # retrieve a list of all members
123         try:
124             self.cursor.execute(sql, params)
125             members_list = self.cursor.fetchall()
126         except pgdb.Error, e:
127             raise DBException("SELECT statement failed: %s" % e)
128         
129         # build a dictionary of dictionaries from the result (a list of lists)
130         members_dict = {}
131         for member in members_list:
132             members_dict[member[0]] = {
133                 'memberid': member[0],
134                 'name': member[1],
135                 'studentid': member[2],
136                 'program': member[3],
137                 'type': member[4],
138                 'userid': member[5],
139             }
140
141         return members_dict
142
143
144     def select_single_member(self, sql, params=None):
145         """
146         Retrieves a single member by memberid.
147
148         This is a helper function that should generally not be called directly.
149         
150         See: self.select_members()
151
152         Returns: a column-keyed dictionary with member attributes, or
153                  None if no member matching member exists
154         """
155
156         # retrieve the member
157         results = self.select_members(sql, params)
158
159         # too many members returned
160         if len(results) > 1:
161             raise DBException("multiple members selected: sql='%s' params=%s" % (sql, repr(params)))
162
163         # no such member
164         elif len(results) < 1:
165             return None
166
167         # return the single match
168         memberid = results.keys()[0]
169         return results[memberid]
170
171    
172     def select_all_members(self):
173         """
174         Retrieves a list of all CSC members (past and present).
175
176         See: self.select_members()
177         
178         Example: connection.select_all_members() -> {
179                      0:    { 'memberid': 0, 'name': 'Calum T. Dalek' ...}
180                      3349: { 'memberid': 3349, 'name': 'Michael Spang' ...}
181                      ...
182                  }
183         """
184         sql = "SELECT memberid, name, studentid, program, type, userid FROM members"
185         return self.select_members(sql)
186         
187     
188     def select_members_by_name(self, name_re):
189         """
190         Retrieves a list of all CSC members whose name matches name_re.
191         
192         See: self.select_members()
193         
194         Example: connection.select_members_by_name('Michael') -> {
195                      3349: { 'memberid': 3349, 'name': 'Michael Spang' ...}
196                      ...
197                  }
198         """
199         sql = "SELECT memberid, name, studentid, program, type, userid FROM members WHERE name ~* %s"
200         params = [ str(name_re) ]
201      
202         return self.select_members(sql, params)
203
204     
205     def select_members_by_term(self, term):
206         """
207         Retrieves a list of all CSC members who were members in the specified term.
208         
209         See: self.select_members()
210         
211         Example: connection.select_members_by_term('f2006') -> {
212                      3349: { 'memberid': 3349, 'name': 'Michael Spang' ...}
213                      ...
214                  }
215         """
216         sql = "SELECT members.memberid, name, studentid, program, type, userid FROM members JOIN terms ON members.memberid=terms.memberid WHERE term=%s"
217         params = [ str(term) ]
218         
219         return self.select_members(sql, params)
220
221     
222     def select_member_by_id(self, memberid):
223         """
224         Retrieves a single member by memberid.
225
226         See: self.select_single_member()
227
228         Example: connection.select_member_by_id(0) ->
229                  { 'memberid': 0, 'name': 'Calum T. Dalek' ...}
230         """
231         sql = "SELECT memberid, name, studentid, program, type, userid FROM members WHERE memberid=%d"
232         params = [ int(memberid) ]
233
234         return self.select_single_member(sql, params)
235
236     
237     def select_member_by_userid(self, username):
238         """
239         Retrieves a single member by UNIX account username.
240
241         See: self.select_single_member()
242
243         Example: connection.select_member_by_userid('ctdalek') ->
244                  { 'memberid': 0, 'name': 'Calum T. Dalek' ...}
245         """
246         sql = "SELECT memberid, name, studentid, program, type, userid FROM members WHERE userid=%s"
247         params = [ username ]
248
249         return self.select_single_member(sql, params)
250
251
252     def select_member_by_studentid(self, studentid):
253         """
254         Retrieves a single member by student id number.
255
256         See: self.select_single_member()
257
258         Example: connection.select_member_by_studentid('nnnnnnnn') ->
259                  { 'memberid': 3349, 'name': 'Michael Spang' ...}
260         """
261         sql = "SELECT memberid, name, studentid, program, type, userid FROM members WHERE studentid=%s"
262         params = [ studentid ]
263
264         return self.select_single_member(sql, params)
265
266     
267     def insert_member(self, name, studentid=None, program=None, mtype='user', userid=None):
268         """
269         Creates a member with the specified attributes.
270
271         Parameters:
272             name      - full name of member
273             studentid - student id number
274             program   - program of study
275             mtype     - member type
276             userid    - account id
277
278         Example: connection.insert_member('Michael Spang', '99999999', 'Math/CS') -> 3349
279
280         Returns: a memberid of created user
281         """
282         try:
283             # retrieve the next memberid
284             sql = "SELECT nextval('memberid_seq')"
285             self.cursor.execute(sql)
286             result = self.cursor.fetchone()
287             memberid = result[0]
288         
289             # insert the member
290             sql = "INSERT INTO members (memberid, name, studentid, program, type, userid) VALUES (%d, %s, %s, %s, %s, %s)"
291             params = [ memberid, name, studentid, program, mtype, userid ]
292             self.cursor.execute(sql, params)
293             
294             return memberid
295         except pgdb.Error, e:
296             raise DBException("failed to create member: %s" % e)
297
298     
299     def delete_member(self, memberid):
300         """
301         Deletes a member. Note that a member cannot
302         be deleted until it has been unregistered from
303         all terms.
304
305         Parameters:
306             memberid - the member id number to delete
307
308         Example: connection.delete_member(3349)
309         """
310         sql = "DELETE FROM members WHERE memberid=%d"
311         params = [ memberid ]
312
313         try:
314             self.cursor.execute(sql, params)
315         except pgdb.Error, e:
316             raise DBException("DELETE statement failed: %s" %e)
317
318     
319     def update_member(self, member):
320         """
321         Modifies member attributes.
322
323         Parameters:
324             member - a column-keyed dictionary with the new state of the member.
325                      member['memberid'] must be present. ommitted columns
326                      will not be changed. None is NULL.
327
328         Returns: the full new state of the member as a column-keyed dictionary
329
330         Example: connection.update_member({
331                      'memberid': 3349,
332                      'name': 'Michael C. Spang',
333                      'program': 'CS!'
334                  }) -> {
335                      'memberid': 3349,
336                      'name': 'Michael C. Spang',
337                      'program': CS!',
338                      'studentid': '99999999' # unchanged
339                  }
340
341         Equivalent Example:
342                  member = connection.select_member_by_id(3349)
343                  member['name'] = 'Michael C. Spang'
344                  member['program'] = 'CS!'
345                  connection.update_member(member) -> { see above }
346         """
347         try:
348             
349             # memberid to update
350             memberid = member['memberid']
351             
352             # retrieve current state of member
353             member_state = self.select_member_by_id(memberid)
354
355             # build a list of changes to make
356             changes = []
357             for column in member.keys():
358                 if member[column] != member_state[column]:
359
360                     # column's value must be updated
361                     changes.append( (column, member[column]) )
362                     member_state[column] = member[column]
363             
364             # no changes?
365             if len(changes) < 1:
366                 return member_state
367             
368             # make the necessary changes in an update statement
369             changes = zip(*changes)
370             sql = "UPDATE members SET " + ", ".join(["%s=%%s"] * len(changes[0])) % changes[0] + " WHERE memberid=%d"
371             params = changes[1] + ( memberid, )
372             self.cursor.execute(sql, params)
373
374             return member_state
375         except pgdb.Error, e:
376             raise DBException("member update failed: %s" % e)
377         
378
379
380     ### Term-related methods ###
381
382     def select_term(self, memberid, term):
383         """
384         Determines whether a member is registered for a term.
385         
386         Parameters:
387             memberid - the member id number
388             term     - the term to check
389
390         Returns: a matching term, or None
391
392         Example: connection.select_term(3349, 'f2006') -> 'f2006'
393         """
394         sql = "SELECT term FROM terms WHERE memberid=%d AND term=%s"
395         params = [ memberid, term ]
396
397         # retrieve matches
398         try:
399             self.cursor.execute(sql, params)
400             result = self.cursor.fetchall()
401         except pgdb.Error, e:
402             raise DBException("SELECT statement failed: %s" % e)
403
404         if len(result) > 1:
405             raise DBException("multiple rows in terms with memberid=%d term=%s" % (memberid, term))
406         elif len(result) == 0:
407             return None
408         else:
409             return result[0][0]
410
411
412     def select_terms(self, memberid):
413         """
414         Retrieves a list of terms a member is registered for.
415
416         Parameters:
417             memberid - the member id number
418
419         Returns: a sorted list of terms
420         
421         Example: connection.select_terms(3349) -> ['f2006']
422         """
423         sql = "SELECT term FROM terms WHERE memberid=%d"
424         params = [ memberid ]
425
426         # retrieve the list of terms
427         try:
428             self.cursor.execute(sql, params)
429             result = self.cursor.fetchall()
430         except pgdb.Error, e:
431             raise DBException("SELECT statement failed: %s" % e)
432         
433         result = [ row[0] for row in result ]
434
435         return result
436
437
438     def insert_term(self, memberid, term):
439         """
440         Registers a member for a term.
441
442         Parameters:
443             memberid - the member id number to register
444             term     - string representation of the term
445
446         Example: connection.insert_term(3349, 'f2006')
447         """
448         sql = "INSERT INTO terms (memberid, term) VALUES (%d, %s)"
449         params = [ memberid, term ]
450         
451         try:
452             self.cursor.execute(sql, params)
453         except pgdb.Error, e:
454             raise DBException("INSERT statement failed: %s" % e)
455
456
457     def delete_term(self, memberid, term):
458         """
459         Unregisters a member for a term.
460
461         Parameters:
462             memberid - the member id number to register
463             term     - string representation of the term
464         
465         Example: connection.delete_term(3349, 'f2006')
466         """
467         sql = "DELETE FROM terms WHERE memberid=%d and term=%s"
468         params = [ memberid, term ]
469
470         try:
471             self.cursor.execute(sql, params)
472         except pgdb.Error, e:
473             raise DBException("DELETE statement failed: %s" % e)
474
475     
476     def delete_term_all(self, memberid):
477         """
478         Unregisters a member for all registered terms.
479
480         Parameters:
481             memberid - the member id number to unregister
482         
483         Example: connection.delete_term_all(3349)
484         """
485         sql = "DELETE FROM terms WHERE memberid=%d"
486         params = [ memberid ]
487         
488         # retrieve a list of terms
489         try:
490             self.cursor.execute(sql, params)
491         except pgdb.Error, e:
492             raise DBException("DELETE statement failed: %s" % e)
493
494
495     ### Miscellaneous methods ###
496
497     def trim_memberid_sequence(self):
498         """
499         Sets the value of the member id sequence to the id of the newest
500         member. For use after testing to prevent large intervals of unused
501         memberids from developing.
502
503         Note: this does nothing unless the most recently added member(s) have been deleted
504         """
505         self.cursor.execute("SELECT setval('memberid_seq', (SELECT max(memberid) FROM members))")
506
507
508
509 ### Tests ###
510
511 if __name__ == '__main__':
512
513     from csc.common.test import *
514  
515     conffile = "/etc/csc/pgsql.cf"
516
517     cfg = dict([map(str.strip, a.split("=", 1)) for a in map(str.strip, open(conffile).read().split("\n")) if "=" in a ])
518     hostnm = cfg['server'][1:-1]
519     dbase = cfg['database'][1:-1]
520
521     # t=test m=member s=student d=default e=expected u=updated
522     tmname = 'Test Member'
523     tmuname = 'Member Test'
524     tmsid = '00000004'
525     tmusid = '00000008'
526     tmprogram = 'Undecidable'
527     tmuprogram = 'Nondetermined'
528     tmtype = 'Untyped'
529     tmutype = 'Poly'
530     tmuserid = 'tmem'
531     tmuuserid = 'identifier'
532     tm2name = 'Test Member 2'
533     tm2sid = '00000005'
534     tm2program = 'Undeclared'
535     tm3name = 'T. M. 3'
536     dtype = 'user'
537     tmterm = 'w0000'
538     tm3term = 'f1112'
539     tm3term2 = 's1010'
540
541     emdict = { 'name': tmname, 'program': tmprogram, 'studentid': tmsid, 'type': tmtype, 'userid': tmuserid }
542     emudict = { 'name': tmuname, 'program': tmuprogram, 'studentid': tmusid, 'type': tmutype, 'userid': tmuuserid }
543     em2dict = { 'name': tm2name, 'program': tm2program, 'studentid': tm2sid, 'type': dtype, 'userid': None }
544     em3dict = { 'name': tm3name, 'program': None, 'studentid': None, 'type': dtype, 'userid': None }
545     
546     test(DBConnection)
547     connection = DBConnection()
548     success()
549
550     test(connection.connect)
551     connection.connect(hostnm, dbase)
552     success()
553
554     test(connection.connected)
555     assert_equal(True, connection.connected())
556     success()
557
558     test(connection.insert_member)
559     tmid = connection.insert_member(tmname, tmsid, tmprogram, tmtype, tmuserid)
560     tm2id = connection.insert_member(tm2name, tm2sid, tm2program)
561     tm3id = connection.insert_member(tm3name)
562     assert_equal(True, int(tmid) >= 0)
563     assert_equal(True, int(tmid) >= 0)
564     success()
565
566     emdict['memberid'] = tmid
567     emudict['memberid'] = tmid
568     em2dict['memberid'] = tm2id
569     em3dict['memberid'] = tm3id
570
571     test(connection.select_member_by_id)
572     m1 = connection.select_member_by_id(tmid)
573     m2 = connection.select_member_by_id(tm2id)
574     m3 = connection.select_member_by_id(tm3id)
575     assert_equal(emdict, m1)
576     assert_equal(em2dict, m2) 
577     assert_equal(em3dict, m3)
578     success()
579
580     test(connection.select_all_members)
581     members = connection.select_all_members()
582     assert_equal(True, tmid in members)
583     assert_equal(True, tm2id in members)
584     assert_equal(True, tm3id in members)
585     assert_equal(emdict, members[tmid])
586     success()
587
588     test(connection.select_members_by_name)
589     members = connection.select_members_by_name(tmname)
590     assert_equal(True, tmid in members)
591     assert_equal(False, tm3id in members)
592     assert_equal(emdict, members[tmid])
593     success()
594
595     test(connection.select_member_by_userid)
596     assert_equal(emdict, connection.select_member_by_userid(tmuserid))
597     success()
598
599     test(connection.insert_term)
600     connection.insert_term(tmid, tmterm)
601     connection.insert_term(tm3id, tm3term)
602     connection.insert_term(tm3id, tm3term2)
603     success()
604
605     test(connection.select_members_by_term)
606     members = connection.select_members_by_term(tmterm)
607     assert_equal(True, tmid in members)
608     assert_equal(False, tm2id in members)
609     assert_equal(False, tm3id in members)
610     success()
611
612     test(connection.select_term)
613     assert_equal(tmterm, connection.select_term(tmid, tmterm))
614     assert_equal(None, connection.select_term(tm2id, tmterm))
615     assert_equal(tm3term, connection.select_term(tm3id, tm3term))
616     assert_equal(tm3term2, connection.select_term(tm3id, tm3term2))
617     success()
618
619     test(connection.select_terms)
620     trms = connection.select_terms(tmid)
621     trms2 = connection.select_terms(tm2id)
622     assert_equal([tmterm], trms)
623     assert_equal([], trms2)
624     success()
625
626     test(connection.delete_term)
627     assert_equal(tm3term, connection.select_term(tm3id, tm3term))
628     connection.delete_term(tm3id, tm3term)
629     assert_equal(None, connection.select_term(tm3id, tm3term))
630     success()
631
632     test(connection.update_member)
633     connection.update_member({'memberid': tmid, 'name': tmuname})
634     connection.update_member({'memberid': tmid, 'program': tmuprogram, 'studentid': tmusid })
635     connection.update_member({'memberid': tmid, 'userid': tmuuserid, 'type': tmutype })
636     assert_equal(emudict, connection.select_member_by_id(tmid))
637     connection.update_member(emdict)
638     assert_equal(emdict, connection.select_member_by_id(tmid))
639     success()
640
641     test(connection.delete_term_all)
642     connection.delete_term_all(tm2id)
643     connection.delete_term_all(tm3id)
644     assert_equal([], connection.select_terms(tm2id))
645     assert_equal([], connection.select_terms(tm3id))
646     success()
647
648     test(connection.delete_member)
649     connection.delete_member(tm3id)
650     assert_equal(None, connection.select_member_by_id(tm3id))
651     negative(connection.delete_member, (tmid,), DBException, "delete of term-registered member")
652     success()
653
654     test(connection.rollback)
655     connection.rollback()
656     assert_equal(None, connection.select_member_by_id(tm2id))
657     success()
658
659     test(connection.commit)
660     connection.commit()
661     success()
662
663     test(connection.trim_memberid_sequence)
664     connection.trim_memberid_sequence()
665     success()
666
667     test(connection.disconnect)
668     connection.disconnect()
669     assert_equal(False, connection.connected())
670     connection.disconnect()
671     success()