"""
db.py

Database interface for grading utilities
"""
# This file is part of the schoolutils package.
# Copyright (C) 2013 Richard Lawrence <richard.lawrence@berkeley.edu>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.

import sys, sqlite3, datetime

class GradeDBException(Exception):
    """Base exception for grade database errors.

    Optionally carries the SQL query and parameters involved (either may
    be None) so callers can report exactly what was executed.
    """
    def __init__(self, error_str, query=None, params=None):
        # Hand the message to Exception so e.args is populated and
        # repr()/pickling behave like any other exception.
        super(GradeDBException, self).__init__(error_str)
        self.query = query
        self.params = params
        self.error_str = error_str

    def __str__(self):
        return self.error_str

    def print_query(self):
        "Print this exception's query and parameters to stderr"
        # query defaults to None; writing None to a stream raises TypeError.
        query = '' if self.query is None else str(self.query)
        sys.stderr.write(query)
        # keep the parameters on their own line
        if query and not query.endswith('\n'):
            sys.stderr.write('\n')
        sys.stderr.write(str(self.params) + '\n')
    
class NoRecordsFound(GradeDBException):
    "Raised when a lookup that expects exactly one row finds none."

class MultipleRecordsFound(GradeDBException):
    "Raised when a lookup that expects exactly one row finds several."

def connect(path):
    """Open the grade database stored at path.

       The returned sqlite3.Connection uses sqlite3.Row as its row
       factory, so result rows can be indexed by column name as well
       as by position.
    """
    connection = sqlite3.connect(path)
    connection.row_factory = sqlite3.Row
    return connection
    
def gradedb_init(db_connection):
    """Create the grade-database schema on db_connection and commit it.

       Tables created:
         students (id, first_name, last_name, sid, email)
         courses (id, name, number, year, semester)
         course_memberships (id, student_id, course_id)
         assignments (id, course_id, name, description, due_date, grade_type, weight)
         grades (id, assignment_id, student_id, value, timestamp)
       Student sid values must be unique.  Inserting a duplicate
       (student_id, course_id) membership is silently ignored.  Fails if
       any of these tables already exists.
       db_connection should be a sqlite database connection.
    """
    schema = """
    CREATE TABLE students (
      id INTEGER PRIMARY KEY,
      first_name TEXT,
      last_name TEXT,
      sid TEXT UNIQUE,
      email TEXT
    );
    CREATE TABLE courses (
      id INTEGER PRIMARY KEY,
      name TEXT,
      number TEXT,
      year INTEGER,
      semester TEXT
    );
    CREATE TABLE course_memberships (
      id INTEGER PRIMARY KEY,
      student_id INTEGER NOT NULL,
      course_id INTEGER NOT NULL,
      FOREIGN KEY(student_id) REFERENCES students(id),
      FOREIGN KEY(course_id) REFERENCES courses(id),
      UNIQUE(student_id, course_id) ON CONFLICT IGNORE
    );
    CREATE TABLE assignments (
      id INTEGER PRIMARY KEY,
      course_id INTEGER NOT NULL,
      name TEXT,
      description TEXT,
      due_date TEXT,
      grade_type TEXT,
      weight NUMERIC,
      FOREIGN KEY(course_id) REFERENCES courses(id)
    );
    CREATE TABLE grades (
      id INTEGER PRIMARY KEY,
      assignment_id INTEGER NOT NULL,
      student_id INTEGER NOT NULL,
      -- rely on SQLite's dynamic types to store letter grades as text:
      value NUMERIC,
      timestamp TEXT,
      FOREIGN KEY(assignment_id) REFERENCES assignments(id),
      FOREIGN KEY(student_id) REFERENCES students(id)
    );
    """
    db_connection.executescript(schema)
    return db_connection.commit()
    
def insert_sample_data(db_connection):
    """Populate an initialized grade database with a small fixed data set
       (two students, two courses, six assignments, ten grades) and commit.
    """
    sample_rows = """
    INSERT INTO students VALUES (1, 'Richard', 'Lawrence', '98765432', 'richard@example.com');
    INSERT INTO students VALUES (2, 'Austin', 'Powers', '12345678', 'austin@example.com');
    INSERT INTO courses VALUES (1, 'Ancient philosophy', '25A', 2012, 'Fall');
    INSERT INTO courses VALUES (2, 'Introduction to logic', '12A', 2012, 'Spring');
    INSERT INTO course_memberships VALUES (1, 1, 1);
    INSERT INTO course_memberships VALUES (2, 2, 1);
    INSERT INTO course_memberships VALUES (3, 1, 2);
    INSERT INTO assignments VALUES (1, 1, 'Paper 1', 'Socrates paper', '2012-09-17', 'letter', 0.25);
    INSERT INTO assignments VALUES (2, 1, 'Paper 2', 'Plato paper', '2012-10-30', 'letter', 0.25);
    INSERT INTO assignments VALUES (3, 1, 'Paper 3', 'Aristotle paper', '2012-11-26', 'letter', 0.25);
    INSERT INTO assignments VALUES (4, 1, 'Exam grade', 'Final exam', '2012-12-14', 'letter', 0.25);
    INSERT INTO assignments VALUES (5, 2, 'HW1', 'problem set', '2012-01-29', 'points', 105);
    INSERT INTO assignments VALUES (6, 2, 'HW2', 'problem set', '2012-02-05', 'points', 96);
    INSERT INTO grades VALUES (1, 1, 1, 'C-', '1111111111');
    INSERT INTO grades VALUES (2, 2, 1, 'B-', '1111111111');
    INSERT INTO grades VALUES (3, 3, 1, 'A', '1111111111');
    INSERT INTO grades VALUES (4, 4, 1, 'B+', '1111111111');
    INSERT INTO grades VALUES (5, 1, 2, 'A', '1111111113');
    INSERT INTO grades VALUES (6, 2, 2, 'A', '1111111113');
    INSERT INTO grades VALUES (7, 3, 2, 'A', '1111111113');
    INSERT INTO grades VALUES (8, 4, 2, 'A', '1111111113');
    INSERT INTO grades VALUES (9, 5, 1, 104, '1111111113');
    INSERT INTO grades VALUES (10, 6, 1, 90, '1111111113');
    """
    db_connection.executescript(sample_rows)
    return db_connection.commit()

def gradedb_clear(db_connection):
    """Drop every grade-database table and commit.
       Fails if any of the tables does not exist.
    """
    drop_script = """
    DROP TABLE students;
    DROP TABLE courses;
    DROP TABLE course_memberships;
    DROP TABLE assignments;
    DROP TABLE grades;
    """
    db_connection.executescript(drop_script)
    return db_connection.commit()
    
#
# basic CRUD operations and some convenience interfaces
#
def select_courses(db_connection, course_id=None, year=None, semester=None,
                   name=None, number=None, student_id=None):
    """Return all course rows matching every supplied criterion.
       Each row has the columns (id, name, number, year, semester).
       Criteria that are not supplied are ignored.  When student_id is
       given, only courses that student is a member of are returned.
    """
    if student_id:
        # filtering by student requires going through course_memberships
        base_query = """
        SELECT courses.id, courses.name, courses.number, courses.year, courses.semester
        FROM courses, course_memberships, students
        ON (course_memberships.student_id=students.id AND
            course_memberships.course_id=courses.id)
        %(where)s
        """
    else:
        # no student filter: the courses table alone suffices
        base_query = """
        SELECT id, name, number, year, semester
        FROM courses
        %(where)s
        """

    columns = ['courses.id', 'courses.year', 'courses.semester', 'courses.name',
               'courses.number', 'students.id']
    wanted = [course_id, year, semester, name, number, student_id]
    where, params = make_conjunction_clause(columns, wanted)
    cursor = db_connection.execute(add_where_clause(base_query, where), params)
    return cursor.fetchall()

def create_course(db_connection, year=None, semester=None, name=None,
                  number=None):
    """Insert a course row and return its rowid.
       Values that are not supplied are omitted from the INSERT.
    """
    columns, placeholders, params = make_values_clause(
        ['year', 'semester', 'name', 'number'],
        [year, semester, name, number])
    template = """
    INSERT INTO courses (%(fields)s) VALUES (%(places)s);
    """
    db_connection.execute(
        template % {'fields': columns, 'places': placeholders}, params)
    return last_insert_rowid(db_connection)

def create_or_update_course(db_connection, course_id=None, year=None,
                            semester=None, name=None, number=None):
    """Insert a course, or replace the existing course with the same id.
       Returns the rowid of the written row.

       WARNING: This function uses SQLite's INSERT OR REPLACE
       statement rather than an UPDATE statement.  If you pass
       course_id, it *will* erase data in an existing row of the
       courses table on a conflict; you must provide all values to
       replace the existing data.
    """
    template = """
    INSERT OR REPLACE INTO courses (%(fields)s) VALUES (%(places)s);
    """
    columns, placeholders, params = make_values_clause(
        ['id', 'year', 'semester', 'name', 'number'],
        [course_id, year, semester, name, number])
    statement = template % {'fields': columns, 'places': placeholders}
    db_connection.execute(statement, params)
    return last_insert_rowid(db_connection)

def delete_course_etc(db_connection, course_id=None):
    """Delete one course along with everything that refers to it.
       Returns the number of course rows deleted.

       WARNING: this function deletes all assignments, grades, and memberships
       associated with this course.
       course_id is required (ValueError otherwise); this function will not
       delete more than one course.  Nothing is committed here.
    """
    if not course_id:
        raise ValueError("course_id is required to delete course row.")

    # Order matters: the grades statement selects from assignments, so it
    # must run before the assignments are removed; the course row goes last
    # so that changes() afterwards reports the course deletion.
    cascade = [
        """
    DELETE FROM grades WHERE assignment_id IN
      (SELECT id FROM assignments WHERE course_id=?);
    """,
        """
    DELETE FROM assignments WHERE course_id=?;
    """,
        """
    DELETE FROM course_memberships WHERE course_id=?;
    """,
        """
    DELETE FROM courses WHERE id=?;
    """,
    ]
    for statement in cascade:
        db_connection.execute(statement, (course_id,))

    return num_changes(db_connection)
    
def select_assignments(db_connection, assignment_id=None, course_id=None,
                       year=None, semester=None, name=None):
    """Return assignment rows matching every supplied criterion.
       Each row has the columns:
       (id, course_id, name, due_date, grade_type, weight, description)
       Rows whose weight is 'CALC' sort last; within each group rows are
       ordered by due_date ascending.
    """
    template = """
    SELECT assignments.id, courses.id AS course_id,
           assignments.name, assignments.due_date,
           assignments.grade_type, assignments.weight, assignments.description
    FROM assignments, courses
    ON assignments.course_id=courses.id
    %(where)s
    ORDER BY CASE WHEN assignments.weight='CALC' THEN 1 ELSE 0 END,
             assignments.due_date ASC;
    """
    filters = [
        ('assignments.id', assignment_id),
        ('courses.year', year),
        ('courses.semester', semester),
        ('courses.id', course_id),
        ('assignments.name', name),
    ]
    where, params = make_conjunction_clause(
        [column for column, _ in filters],
        [value for _, value in filters])
    cursor = db_connection.execute(add_where_clause(template, where), params)
    return cursor.fetchall()

def create_assignment(db_connection, course_id=None, name=None, description=None,
                      due_date=None, grade_type=None, weight=None):
    """Insert an assignment row and return its rowid.
       Values that are not supplied are omitted from the INSERT.
    """
    columns, placeholders, params = make_values_clause(
        ['course_id', 'name', 'description', 'due_date', 'grade_type', 'weight'],
        [course_id, name, description, due_date, grade_type, weight])
    template = """
    INSERT INTO assignments (%(fields)s) VALUES (%(places)s);
    """
    db_connection.execute(
        template % {'fields': columns, 'places': placeholders}, params)
    return last_insert_rowid(db_connection)
  
def create_or_update_assignment(db_connection, assignment_id=None,
                                course_id=None, name=None, description=None,
                                due_date=None, grade_type=None, weight=None):
    """Insert an assignment, or replace the one with the same id.
       Returns the rowid of the written row.

       WARNING: This function uses SQLite's INSERT OR REPLACE
       statement rather than an UPDATE statement.  If you pass
       assignment_id, it *will* erase data in an existing row of the
       assignments table on a conflict; you must provide all values to
       replace the existing data.
    """
    template = """
    INSERT OR REPLACE INTO assignments (%(fields)s) VALUES (%(places)s);
    """
    field_names = ['id', 'course_id', 'name', 'description', 'due_date',
                   'grade_type', 'weight']
    field_values = [assignment_id, course_id, name, description, due_date,
                    grade_type, weight]
    columns, placeholders, params = make_values_clause(field_names, field_values)

    statement = template % {'fields': columns, 'places': placeholders}
    db_connection.execute(statement, params)
    return last_insert_rowid(db_connection)

def delete_assignment_and_grades(db_connection, assignment_id=None):
    """Remove one assignment and every grade recorded for it.
       Returns the number of assignment rows deleted.

       assignment_id is required (ValueError otherwise); this function
       will not delete more than one assignment.
    """
    if not assignment_id:
        raise ValueError("assignment_id is required to delete assignment row.")

    # grades first, then the assignment, so changes() reflects the latter
    for statement in ("DELETE FROM grades WHERE assignment_id=?;",
                      "DELETE FROM assignments WHERE id=?;"):
        db_connection.execute(statement, (assignment_id,))

    return num_changes(db_connection)

def select_students(db_connection, student_id=None, year=None, semester=None,
                    course_id=None, course_name=None, last_name=None,
                    first_name=None, sid=None, email=None,
                    fuzzy=False):
    """Return student rows matching every supplied criterion, ordered by
       last name and then first name.
       Each row has the columns (id, last_name, first_name, sid, email).
       If fuzzy is True, last_name, first_name, email and course_name are
         matched with SQLite's LIKE (case-insensitive), wrapped in '%' on
         both sides; all other criteria always match exactly.
    """
    if course_id or course_name:
        base_query = """
        SELECT students.id, students.last_name, students.first_name,
               students.sid, students.email
        FROM students, course_memberships, courses
        ON (course_memberships.student_id=students.id AND
            course_memberships.course_id=courses.id)
        %(where)s
        ORDER BY students.last_name ASC, students.first_name ASC
        """
    else:
        # Joining without a course constraint would return one row per
        # membership (duplicates) and is slower, so skip the join.
        base_query = """
        SELECT students.id, students.last_name, students.first_name,
               students.sid, students.email
        FROM students
        %(where)s
        ORDER BY students.last_name ASC, students.first_name ASC
        """

    exact = [('courses.year', year), ('courses.semester', semester),
             ('courses.id', course_id), ('students.id', student_id),
             ('students.sid', sid)]
    loose = [('students.last_name', last_name),
             ('students.first_name', first_name),
             ('students.email', email),
             ('courses.name', course_name)]

    if fuzzy:
        # glob on both sides of every non-empty LIKE pattern
        loose = [(column, '%' + pattern + '%' if pattern else pattern)
                 for column, pattern in loose]
    else:
        exact.extend(loose)
        loose = []

    where, params = make_conjunction_clause(
        [column for column, _ in exact], [value for _, value in exact])
    where, params = make_conjunction_clause(
        [column for column, _ in loose], [value for _, value in loose],
        extra=where, extra_params=params, cmp_op="LIKE")

    cursor = db_connection.execute(add_where_clause(base_query, where), params)
    return cursor.fetchall()

def get_student_id(db_connection, first_name=None, last_name=None,
                   sid=None, email=None):
    """Find a student in the grade database.
       Searches by (last_name, first_name) OR sid OR email.
       Return the student's id if found uniquely.

       Raises NoRecordsFound if no criteria are given or nothing matches,
       and MultipleRecordsFound if more than one student matches.
    """
    base_query = """
    SELECT id
    FROM students
    %(where)s
    """
    name_constraints, name_params = make_conjunction_clause(
        ['first_name', 'last_name'],
        [first_name, last_name])
    constraints, params = make_disjunction_clause(
        ['sid', 'email'],
        [sid, email],
        extra=name_constraints, extra_params=name_params)

    if not constraints:
        # Without criteria the query has no WHERE clause; in a database
        # holding a single student it would "find" that student by accident.
        raise NoRecordsFound(
            "get_student_id requires first_name, last_name, sid, or email",
            query=base_query, params=params)

    query = add_where_clause(base_query, constraints)

    rows = db_connection.execute(query, params).fetchall()

    return ensure_unique(
        rows,
        err_msg="get_student_id expects to find exactly 1 student",
        query=query, params=params)
    
def create_student(db_connection, first_name=None, last_name=None, sid=None,
                   email=None):
    """Insert a student row and return its rowid.
       Values that are not supplied are omitted from the INSERT.
    """
    columns, placeholders, params = make_values_clause(
        ['first_name', 'last_name', 'sid', 'email'],
        [first_name, last_name, sid, email])
    template = """
    INSERT INTO students (%(fields)s) VALUES (%(places)s);
    """
    db_connection.execute(
        template % {'fields': columns, 'places': placeholders}, params)
    return last_insert_rowid(db_connection)

def update_student(db_connection, student_id=None, last_name=None,
                   first_name=None, sid=None, email=None):
    """Update a record of an existing student.

       If student_id is not provided, this function attempts to find a
       unique existing student using get_student_id with the given
       last_name, first_name and sid (email is not used for the lookup).
       It then updates the database with whatever information has been
       provided, by overlaying the given criteria on the existing data.
       (That is, it will not replace existing data with a NULL value.)
       Returns the id of the updated row.

       Raises NoRecordsFound if no student row has the resolved id;
       exceptions from get_student_id propagate unchanged.
    """
    if not student_id:
        student_id = get_student_id(
            db_connection,
            last_name=last_name, first_name=first_name, sid=sid)

    fields = ['last_name', 'first_name', 'sid', 'email']
    select_query = "SELECT %s FROM students WHERE id=?" % ', '.join(fields)
    old_values = db_connection.execute(select_query, (student_id,)).fetchone()
    if old_values is None:
        # Fail with a descriptive error instead of letting overlay() raise
        # a TypeError on None.
        raise NoRecordsFound(
            "update_student found no student with id %r" % (student_id,),
            query=select_query, params=(student_id,))
    new_values = [last_name, first_name, sid, email]
    update_values = overlay(old_values, new_values)

    update_clause, params = make_constraint_clause(", ", fields, update_values)
    if not update_clause:
        # Nothing to write; an empty SET clause would be an SQL syntax error.
        return student_id

    base_query = """
    UPDATE students
    SET %(updates)s
    WHERE %(where)s;
    """
    query = base_query % {'updates': update_clause,
                          'where': "id=?"}
    db_connection.execute(query, params + (student_id,))

    return student_id

def create_or_update_student(db_connection, student_id=None, last_name=None,
                             first_name=None, sid=None, email=None):
    """Insert a student, or replace the conflicting existing student.
       Returns the rowid of the written row.

       WARNING: This function uses SQLite's INSERT OR REPLACE
       statement rather than an UPDATE statement.  If you pass
       student_id or sid, it *will* erase data in an existing row of
       the students table on a conflict; you must provide all values
       to replace the existing data.
    """
    template = """
    INSERT OR REPLACE INTO students (%(fields)s) VALUES (%(places)s);
    """
    field_names = ['id', 'last_name', 'first_name', 'sid', 'email']
    field_values = [student_id, last_name, first_name, sid, email]
    columns, placeholders, params = make_values_clause(field_names, field_values)

    statement = template % {'fields': columns, 'places': placeholders}
    db_connection.execute(statement, params)
    return last_insert_rowid(db_connection)

def select_course_memberships(db_connection, member_id=None, course_id=None,
                              student_id=None):
    """Return course_membership rows matching every supplied criterion.
       Each row has the columns (id, course_id, student_id).
       For rows joined with student or course details, use
         select_students or select_courses.
    """
    template = """
    SELECT id, course_id, student_id
    FROM course_memberships
    %(where)s
    """
    where, params = make_conjunction_clause(
        ['id', 'course_id', 'student_id'],
        [member_id, course_id, student_id])
    cursor = db_connection.execute(add_where_clause(template, where), params)
    return cursor.fetchall()
    
def create_course_member(db_connection, course_id=None, student_id=None):
    """Create a new course_membership record in the database.
       Returns the id of the inserted row.

       The course_memberships table ignores duplicate
       (student_id, course_id) pairs.  When the pair already exists,
       nothing is inserted and the id of the existing membership is
       returned instead.
    """
    base_query = """
    INSERT INTO course_memberships (%(fields)s) VALUES (%(places)s);
    """
    fields, places, params = make_values_clause(
        ['course_id', 'student_id'],
        [course_id, student_id])
    query = base_query % {'fields': fields, 'places': places}
    db_connection.execute(query, params)

    if num_changes(db_connection) == 0:
        # ON CONFLICT IGNORE skipped the insert, so last_insert_rowid()
        # would report some earlier, unrelated insert.  Return the id of
        # the membership that already exists instead.
        lookup = ("SELECT id FROM course_memberships "
                  "WHERE course_id=? AND student_id=?")
        lookup_params = (course_id, student_id)
        rows = db_connection.execute(lookup, lookup_params).fetchall()
        return ensure_unique(
            rows,
            err_msg="create_course_member expects exactly 1 existing membership",
            query=lookup, params=lookup_params)

    return last_insert_rowid(db_connection)

def delete_course_member(db_connection, member_id=None, course_id=None,
                         student_id=None):
    """Delete a course_membership record in the database.
       Deletes any row from course_memberships where either:
         id = member_id, OR
         (course_id = course_id AND student_id = student_id)
       You must pass either member_id or both course_id and student_id
         (sqlite3.IntegrityError otherwise).  A course_id or student_id
         given without its partner is ignored, so it can never widen the
         deletion to every membership of a course or of a student.
       To delete multiple rows, see delete_course_members.
       Returns the number of deleted rows.  This is normally at most 1;
         it can be 2 if member_id and the (course_id, student_id) pair
         identify different memberships.
    """
    # sanity check:
    if not (member_id or (student_id and course_id)):
        raise sqlite3.IntegrityError(
            "delete_course_member requires either member_id or BOTH "
            "student_id and course_id")

    base_query = """
    DELETE FROM course_memberships
    %(where)s;
    """
    if student_id and course_id:
        constraints, params = make_conjunction_clause(
            ['course_id', 'student_id'],
            [course_id, student_id])
    else:
        # A lone course_id/student_id would be OR'ed with id = ? and
        # delete every membership of that course or student.
        constraints, params = '', ()
    constraints, params = make_disjunction_clause(
        ['id'], [member_id],
        extra=constraints, extra_params=params)
    query = add_where_clause(base_query, constraints)

    db_connection.execute(query, params)
    return num_changes(db_connection)

def delete_course_members(db_connection, member_id=None, course_id=None,
                          student_id=None):
    """Delete every course_membership row matching all supplied criteria.
       At least one criterion is required (sqlite3.IntegrityError
         otherwise).  To empty the whole table, issue a custom
         DELETE FROM or DROP TABLE statement instead.
       Returns the number of deleted rows.
    """
    if not any((member_id, course_id, student_id)):
        raise sqlite3.IntegrityError(
            "delete_course_members will not delete all rows in the "
            "course_memberships table")

    template = """
    DELETE FROM course_memberships
    %(where)s;
    """
    where, params = make_conjunction_clause(
        ['id', 'course_id', 'student_id'],
        [member_id, course_id, student_id])
    db_connection.execute(add_where_clause(template, where), params)
    return num_changes(db_connection)
    
def select_grades(db_connection, grade_id=None, student_id=None,
                  course_id=None, assignment_id=None):
    """Return recorded grades matching every supplied criterion.
       Each row has the columns:
       (id, student_id, course_id, assignment_id, assignment_name, value)
       Only grades that actually exist are returned; see
         select_grades_for_course_members for rows covering missing grades.
    """
    template = """
    SELECT grades.id,
           students.id AS student_id,
           assignments.course_id AS course_id, assignments.id AS assignment_id,
           assignments.name AS assignment_name,
           grades.value
    FROM grades, assignments, students
    ON grades.assignment_id=assignments.id AND grades.student_id=students.id
    %(where)s
    """
    filters = [
        ('grades.id', grade_id),
        ('students.id', student_id),
        ('assignments.course_id', course_id),
        ('assignments.id', assignment_id),
    ]
    where, params = make_conjunction_clause(
        [column for column, _ in filters],
        [value for _, value in filters])
    cursor = db_connection.execute(add_where_clause(template, where), params)
    return cursor.fetchall()

def select_grades_for_course_members(db_connection, student_id=None, course_id=None):
    """Select grades for members of a course, for every assignment in it.

       The result set is meant to carry everything needed to calculate
       grades in simple cases.  It differs from select_grades in two ways:
       1) Every (course member, assignment in that member's course) pair
          yields a row whether or not a grade exists; a missing grade
          shows up with NULL grade_id and value (LEFT OUTER JOIN).
       2) Each row also carries the assignment's weight and grade_type.

       Columns, in order:
       assignment_id, assignment_name, weight, grade_type, student_id,
       grade_id, value
    """
    template = """
    SELECT assignments.id AS assignment_id,
           assignments.name AS assignment_name,
           assignments.weight,
           assignments.grade_type,
           course_memberships.student_id,
           grades.id AS grade_id,
           grades.value
    FROM (course_memberships, assignments USING (course_id))
         LEFT OUTER JOIN grades ON (course_memberships.student_id=grades.student_id AND assignments.id=grades.assignment_id)
    %(where)s;
    """
    where, params = make_conjunction_clause(
        ['course_memberships.course_id', 'course_memberships.student_id'],
        [course_id, student_id])
    cursor = db_connection.execute(add_where_clause(template, where), params)
    return cursor.fetchall()

def create_grade(db_connection, assignment_id=None, student_id=None, value=None,
                 timestamp=None):
    """Create a new grade in the database.
       The timestamp defaults to the current local time if not provided.
       datetime/date timestamps are stored as ISO 8601 text, exactly as
       sqlite3's default adapters would store them.  Other timestamp
       values are passed through unchanged.
       Returns the id of the inserted row.
    """
    if not timestamp:
        timestamp = datetime.datetime.now()
    # Convert explicitly instead of relying on sqlite3's default datetime
    # adapters, which are deprecated as of Python 3.12.  The stored text is
    # identical.  datetime must be checked before date (it is a subclass).
    if isinstance(timestamp, datetime.datetime):
        timestamp = timestamp.isoformat(' ')
    elif isinstance(timestamp, datetime.date):
        timestamp = timestamp.isoformat()

    base_query = """
    INSERT INTO grades (%(fields)s) VALUES (%(places)s);
    """
    fields, places, params = make_values_clause(
        ['assignment_id', 'student_id', 'value', 'timestamp'],
        [assignment_id, student_id, value, timestamp])
    query = base_query % {'fields': fields, 'places': places}
    db_connection.execute(query, params)

    return last_insert_rowid(db_connection)

def create_or_update_grade(db_connection, grade_id=None, assignment_id=None,
                           student_id=None, value=None, timestamp=None):
    """Create a new grade or update a record of an existing grade.
       Returns the id of the created or updated row.
       The timestamp defaults to the current local time.  datetime/date
       timestamps are stored as ISO 8601 text, exactly as sqlite3's
       default adapters would store them.

       WARNING: This function uses SQLite's INSERT OR REPLACE
       statement rather than an UPDATE statement.  If you pass
       grade_id, it *will* erase data in an existing row of the grades
       table; you must provide all values to replace the existing data.
    """
    base_query = """
    INSERT OR REPLACE INTO grades (%(fields)s) VALUES (%(places)s);
    """
    if not timestamp:
        timestamp = datetime.datetime.now()
    # Convert explicitly instead of relying on sqlite3's default datetime
    # adapters, which are deprecated as of Python 3.12.  The stored text is
    # identical.  datetime must be checked before date (it is a subclass).
    if isinstance(timestamp, datetime.datetime):
        timestamp = timestamp.isoformat(' ')
    elif isinstance(timestamp, datetime.date):
        timestamp = timestamp.isoformat()

    fields, places, params = make_values_clause(
        ['id', 'assignment_id', 'student_id', 'value', 'timestamp'],
        [grade_id, assignment_id, student_id, value, timestamp])

    query = base_query % {'fields': fields, 'places': places}
    db_connection.execute(query, params)

    return last_insert_rowid(db_connection)

def update_grade(db_connection, grade_id=None, value=None):
    """Update a record of an existing grade.
       Returns the id of the updated row.

       This function is, for now, intentionally hobbled: you can only
       update a grade's value field, and you can only select a grade
       by its id field.  (Thus you may only update one grade.)  The
       timestamp is set to the current local time, stored as ISO 8601
       text.

       Raises sqlite3.IntegrityError if grade_id is missing, or if value
       is None or the empty string.  Other falsy values such as 0 are
       legitimate grades and are accepted.
    """
    if not grade_id:
        raise sqlite3.IntegrityError("grade_id is required to update a grade")
    if value is None or value == '':
        # 0 is a real (points) grade; only reject genuinely missing values.
        raise sqlite3.IntegrityError("value is required to update a grade")

    query = """
    UPDATE grades
    SET value=?, timestamp=?
    WHERE id=?;
    """
    # Same text sqlite3's default datetime adapter would produce, without
    # relying on that adapter (deprecated as of Python 3.12).
    timestamp = datetime.datetime.now().isoformat(' ')
    params = (value, timestamp, grade_id)
    db_connection.execute(query, params)

    return grade_id

   
           
#            
# utilities
#
def ensure_unique(rows, err_msg='', query='', params=None):
    """Return the first column of the single row in rows.
       Raises NoRecordsFound for zero rows and MultipleRecordsFound for
       more than one; err_msg, query and params are attached to the error.
    """
    count = len(rows)
    if count == 1:
        return rows[0][0]
    if count == 0:
        raise NoRecordsFound(err_msg, query=query, params=params)
    raise MultipleRecordsFound(err_msg, query=query, params=params)

def overlay(old_values, new_values):
    """Merge new_values over old_values position by position.
       A new value wins unless it is falsy (or equal to the old one), in
       which case the old value is kept.  The result is a list as long as
       the shorter of the two inputs.
    """
    return [old if (not new or old == new) else new
            for old, new in zip(old_values, new_values)]

def make_constraint_clause(connective, fields, values,
                           extra='', extra_params=tuple(),
                           cmp_op="="):
    """Construct a constraint clause and a set of parameters for it.
       Returns a tuple of the constraint clause as a string and the
       parameter values as a tuple.

       A field is constrained only when its value is given: None and
       the empty string mean "no constraint".  Other falsy values such
       as 0 are real values and are constrained on.

       If provided, extra should be a string to prepend in parentheses
       to the generated constraint clause, and extra_params should be
       a tuple of parameters to prepend to the generated parameters.
       Using these arguments, one can incrementally construct complex
       constraints, e.g. "(field1=? AND field2=?) OR field3=?"

       cmp_op, if provided, should be a string specifying a binary
       comparison operator.  The default is "="; "!=", "LIKE" etc.
       are other useful options.  Spaces are automatically added on
       either side and a parameter place '?' is added to the right
       hand side.
    """
    constraints = []
    params = []
    for field, value in zip(fields, values):
        # `if value:` would silently drop legitimate zeros (grade 0, weight 0)
        if value is None or value == '':
            continue
        constraints.append(field + " " + cmp_op + " ?")
        params.append(value)

    if extra:
        constraints.insert(0, "(" + extra + ")")

    clause = connective.join(constraints)

    return clause, tuple(extra_params) + tuple(params)

def make_conjunction_clause(fields, values,
                            extra='', extra_params=tuple(),
                            cmp_op="="):
    "Join per-field constraints with AND (see make_constraint_clause)"
    connective = " AND "
    return make_constraint_clause(connective, fields, values,
                                  extra, extra_params, cmp_op)

def make_disjunction_clause(fields, values,
                            extra='', extra_params=tuple(),
                            cmp_op="="):
    "Join per-field constraints with OR (see make_constraint_clause)"
    connective = " OR "
    return make_constraint_clause(connective, fields, values,
                                  extra, extra_params, cmp_op)

def add_where_clause(base_query, constraints):
    """Fill the %(where)s slot of base_query.
       base_query is a dictionary-style format string containing
       %(where)s; constraints is a string of field constraints.  The slot
       becomes "WHERE <constraints>", or is left empty when constraints
       is empty.
    """
    where_clause = ("WHERE " + constraints) if constraints else ''
    return base_query % {'where': where_clause}

def make_values_clause(fields, values):
    """Construct strings of field names and query parameter places, and
       a tuple of parameters, for an INSERT statement.

       Returns (fields_str, places_str, params), e.g.
       ('a, c', '?, ?', (1, 'x')).  A field is included only when its
       value is given: None and the empty string are skipped.  Other
       falsy values such as 0 are stored rather than silently becoming
       NULL.
    """
    used = [(field, value) for field, value in zip(fields, values)
            if not (value is None or value == '')]

    used_fields = [field for field, _ in used]
    places = ['?'] * len(used)
    params = tuple(value for _, value in used)

    return ', '.join(used_fields), ', '.join(places), params

def last_insert_rowid(db_connection):
    "Return SQLite's last_insert_rowid() for db_connection"
    rows = db_connection.execute("SELECT last_insert_rowid()").fetchall()
    return ensure_unique(rows)

def num_changes(db_connection):
    "Return SQLite's changes(): rows touched by the last INSERT/UPDATE/DELETE"
    rows = db_connection.execute("SELECT changes()").fetchall()
    return ensure_unique(rows)

