diff --git a/slotmachine/__init__.py b/slotmachine/__init__.py index a2ef3a0..f82e8ce 100644 --- a/slotmachine/__init__.py +++ b/slotmachine/__init__.py @@ -29,6 +29,7 @@ class Talk: preferred_venues: set[VenueID] = field(default_factory=set) allowed_slots: set[Slot] = field(default_factory=set) preferred_slots: set[Slot] = field(default_factory=set) + must_schedule_after: set[TalkID] = field(default_factory=set) class SlotMachine(object): @@ -87,6 +88,34 @@ def active(self, slot: Slot, talk_id: TalkID, venue: VenueID) -> pulp.LpVariable self.var_cache[name] = variable return variable + def start_slot(self, talk_id: TalkID) -> pulp.LpVariable: + """A variable that is the number of the slot that talk ID is scheduled to begin.""" + name = "S_start_%d" % (talk_id,) + if name in self.var_cache: + return self.var_cache[name] + + variable = pulp.LpVariable(name, cat="Integer") + definition = pulp.lpSum( + self.start_var(slot, talk_id, venue) * slot + for slot in self.talks_by_id[talk_id].allowed_slots + for venue in self.talks_by_id[talk_id].venues + ) + self.problem.addConstraint(variable == definition) + self.var_cache[name] = variable + return variable + + def end_slot(self, talk_id: TalkID) -> pulp.LpVariable: + """A variable that is the number of the slot that talk ID is scheduled to begin.""" + name = "S_end_%d" % (talk_id,) + if name in self.var_cache: + return self.var_cache[name] + + variable = pulp.LpVariable(name, cat="Integer") + definition = self.start_slot(talk_id) + self.talks_by_id[talk_id].duration + self.problem.addConstraint(variable == definition) + self.var_cache[name] = variable + return variable + def get_problem( self, venues: set[VenueID], talks: list[Talk], old_talks: list[OldTalk] ) -> pulp.LpProblem: @@ -107,6 +136,13 @@ def get_problem( == 1 ) + # Talks which must precede other talks do that + for talk in talks: + for schedule_before in talk.must_schedule_after: + self.problem.addConstraint( + self.end_slot(schedule_before) <= self.start_slot(talk.id) + ) + # At most one talk may be active in a given venue and slot. for v in venues: for slot in self.slots_available: diff --git a/tests/test_slotmachine.py b/tests/test_slotmachine.py index e026546..50eb88f 100644 --- a/tests/test_slotmachine.py +++ b/tests/test_slotmachine.py @@ -24,6 +24,7 @@ def talk( venues: list[int], speakers: list[str], slots: Iterable[Slot] | Iterable[int], + must_schedule_after: list[int] = [], ) -> Talk: return Talk( id=TalkID(id), @@ -31,6 +32,7 @@ def talk( venues={VenueID(vid) for vid in venues}, speakers=speakers, allowed_slots={Slot(s) for s in slots}, + must_schedule_after={TalkID(t) for t in must_schedule_after}, ) @@ -253,3 +255,41 @@ def test_talk_clash(self): # Talk 1 must now be in slot 3 or 4 self.assertIn(talks_slots[1], [3, 4]) + + def test_must_schedule_after(self): + avail_slots = SlotMachine.calculate_slots( + parser.parse("2016-08-06 13:00"), + parser.parse("2016-08-06 13:00"), + parser.parse("2016-08-06 15:00"), + ) + _talk = partial(talk, slots=avail_slots[:], venues=[101]) + talk_defs = [ + _talk( + id=1, duration=3 + 1, speakers=["Speaker 1"], must_schedule_after=[2] + ), + _talk( + id=2, duration=2 + 1, speakers=["Speaker 2"], must_schedule_after=[3] + ), + _talk( + id=3, duration=2 + 1, speakers=["Speaker 3"], must_schedule_after=[4] + ), + _talk(id=4, duration=2 + 1, speakers=["Speaker 4"]), + ] + old_talks = [(0, 1, 101), (3, 2, 101), (6, 3, 101), (9, 4, 101)] + solved = self.schedule_and_basic_asserts( + talk_defs, avail_slots, old_talks=old_talks + ) + + slots, talks, venues = unzip(solved) + talks_slots = dict(zip(talks, slots)) + + # The talks are now in the reverse order of the one they were in in old_talks. + self.assertEqual( + { + 4: 0, + 3: 3, + 2: 6, + 1: 9, + }, + talks_slots, + )