diff --git a/access/course.py b/access/course.py index 1b0c7d6..fc5e3e3 100644 --- a/access/course.py +++ b/access/course.py @@ -21,6 +21,15 @@ LOGGER = logging.getLogger('main') +def _get_datetime(value: Any) -> Optional[datetime]: + """Turns date/datetime into a datetime and returns None if given anything else""" + if isinstance(value, date): + return datetime.combine(value, datetime.max.time()) + elif not isinstance(value, datetime): + return None + return value + + class ConfigureOptions(PydanticModel): files: Dict[str,str] = {} # ellipsis (...) makes the field required in the case that a default url isn't specified @@ -354,6 +363,27 @@ def name_or_title(cls, values: Dict[str, Any]): values["name"] = values.pop("title") return values + @root_validator(allow_reuse=True, skip_on_failure=True) + def validate_dates(cls, values: Dict[str, Any]) -> Dict[str, Any]: + close = values.get("close") + close_dt = _get_datetime(close) + + open = values.get("open") + open_dt = _get_datetime(open) + if open_dt and close_dt and open_dt > close_dt: + raise ValueError(f"Module 'close' ({close}) before 'open' ({open})") + + late_close = values.get("late_close") + late_close_dt = _get_datetime(late_close) + if late_close_dt and close_dt and late_close_dt < close_dt: + raise ValueError(f"'late_close' ({late_close}) is before 'close' ({close})") + + read_open = values.get("read_open") + read_open_dt = _get_datetime(read_open) + if read_open_dt and open_dt and read_open_dt >= open_dt: + raise ValueError(f"'read_open' ({read_open}) is not before 'open' ({open})") + + return values NumberingType = Literal["none", "arabic", "roman", "hidden"] @@ -469,24 +499,26 @@ def validate_keys(cls, values: Dict[str, Any]) -> Dict[str, Any]: @root_validator(allow_reuse=True, skip_on_failure=True) def validate_module_dates(cls, values: Dict[str, Any]) -> Dict[str, Any]: - def get_datetime(value): - if isinstance(value, date): - return datetime.combine(value, datetime.max.time()) - elif not isinstance(value, datetime): - return None - return value - - end_date = get_datetime(values.get("end")) - for m in values["modules"]: - close = get_datetime(m.close) - if close and end_date and close > end_date: - m.add_warning(f"Course ends before module closes") - - late_close = get_datetime(m.late_close) - if late_close: - if close: - if late_close < close: - m.add_warning(f"'late_close' is before 'close'") - elif end_date and late_close < end_date: - m.add_warning(f"'late_close' is before module close (which defaults to course 'end')") + end = values.get("end") + end_dt = _get_datetime(end) + for i, m in enumerate(values["modules"]): + close_dt = _get_datetime(m.close) + if close_dt and end_dt and close_dt > end_dt: + m.add_warning(f"Course 'end' ({end}) before module {i} 'close' ({m.close})", "close") + + late_close_dt = _get_datetime(m.late_close) + if late_close_dt and not close_dt and end_dt and late_close_dt < end_dt: + raise ValueError(f"Module {i} 'late_close' ({m.late_close}) is before 'close' ({end}, defaulted to course 'end')") + return values + + @root_validator(allow_reuse=True, skip_on_failure=True) + def validate_course_dates(cls, values: Dict[str, Any]) -> Dict[str, Any]: + start = values.get("start") + end = values.get("end") + start_dt = _get_datetime(start) + end_dt = _get_datetime(end) + + if start_dt and end_dt and start_dt > end_dt: + raise ValueError(f"Course 'end' ({end}) before 'start' ({start})") + return values