Skip to content

Commit

Permalink
add additional checks to provide error messages when inputs are missi…
Browse files Browse the repository at this point in the history
…ng, refine shift creation
  • Loading branch information
harrypuuter committed Sep 7, 2023
1 parent 4192c03 commit 7fe03a8
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 7 deletions.
12 changes: 12 additions & 0 deletions code_generation/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@ def __init__(self, scope: str, outputs: Union[Set[Quantity], List[Quantity]]):
super().__init__(self.message)


class InvalidInputError(ConfigurationError):
"""
Exception raised when the list of avialable inputs does not cover all quantities required.
"""

def __init__(self, scope: str, outputs: Union[Set[str], List[str]]):
self.message = "The required inputs {} for the scope '{}' are not provided by any inputfile or producer \n Please check the error message above to find all misconfigured producers".format(
outputs, scope
)
super().__init__(self.message)


class ScopeConfigurationError(ConfigurationError):
"""
Exception raised when the scope configuration provided by the user is not valid.
Expand Down
75 changes: 68 additions & 7 deletions code_generation/friend_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ConfigurationError,
InvalidOutputError,
InsufficientShiftInformationError,
InvalidInputError,
)
from code_generation.producer import Producer, ProducerGroup
from code_generation.rules import ProducerRule
Expand Down Expand Up @@ -266,6 +267,7 @@ def optimize(self) -> None:
self._apply_rules()
self._add_requested_shifts()
self._remove_empty_scopes()
self._validate_inputs()

def _add_requested_shifts(self) -> None:
# first shift the output quantities
Expand All @@ -274,13 +276,25 @@ def _add_requested_shifts(self) -> None:
if shift != "nominal":
shiftname = "__" + shift
for producer in self.producers[scope]:
log.debug("Adding shift %s to producer %s", shift, producer)
producer.shift(shiftname, scope)
# second step is to shift the inputs of the producer
self._shift_producer_inputs(producer, shift, scope)
self._shift_producer_inputs(producer, shift, shiftname, scope)
self.shifts[scope][shiftname] = {}

def _shift_producer_inputs(self, producer, shift, scope):
def _shift_producer_inputs(
self,
producer: Union[Producer, ProducerGroup],
shift: str,
shiftname: str,
scope: str,
) -> None:
"""Function used to determine which inputs of a producer have to be shifted. If none of the inputs of a producer is available in the shift_quantities_map, the producer is skipped.
Args:
producer (Union[Producer, ProducerGroup]): The producer to be checked and possibly shifted
shift (str): the shift to be added
shiftname (str): the name of the shift to be added
scope (str): The scope to be checked
"""
log.debug("Shifting inputs of producer %s", producer)
# if the producer is not of Type ProducerGroup we can directly shift the inputs
if isinstance(producer, Producer):
Expand All @@ -292,11 +306,20 @@ def _shift_producer_inputs(self, producer, shift, scope):
for input in inputs:
if input.name in self.input_quantities_mapping[scope][shift]:
inputs_to_shift.append(input)
log.debug(f"Shifting inputs {inputs_to_shift} of producer {producer}")
producer.shift_inputs("__" + shift, scope, inputs_to_shift)
if len(inputs_to_shift) > 0:
log.debug("Adding shift %s to producer %s", shift, producer)
producer.shift(shiftname, scope)
log.debug(
f"Shifting inputs {inputs_to_shift} of producer {producer} by {shift}"
)
producer.shift_inputs(shiftname, scope, inputs_to_shift)
else:
log.info(
f"no inputs to shift for producer {producer} and shift {shift}, skipping"
)
elif isinstance(producer, ProducerGroup):
for producer in producer.producers[scope]:
self._shift_producer_inputs(producer, shift, scope)
self._shift_producer_inputs(producer, shift, shiftname, scope)

def _validate_outputs(self) -> None:
"""
Expand All @@ -317,6 +340,44 @@ def _validate_outputs(self) -> None:
if len(missing_outputs) > 0:
raise InvalidOutputError(scope, missing_outputs)

def _validate_inputs(self) -> None:
"""
The `_validate_inputs` function checks if all required inputs for each producer in the given scopes
are available, and raises an error if any inputs are missing.
"""

for scope in [scope for scope in self.scopes]:
# get all inputs of all producers
required_inputs = set()
available_inputs = set()
for producer in self.producers[scope]:
required_inputs = required_inputs | set(
[x.name for x in producer.get_inputs(scope)]
)
available_inputs = available_inputs | set(
[x.name for x in producer.get_outputs(scope)]
)
# get all available inputs
for input in self.input_quantities_mapping[scope][""]:
available_inputs.add(input)
# now check if all inputs are available
missing_inputs = required_inputs - available_inputs
if len(missing_inputs) > 0:
for producer in self.producers[scope]:
if (
len(
missing_inputs
& set([x.name for x in producer.get_inputs(scope)])
)
> 0
):
log.error(f"Missing inputs for {producer}")
log.error(f"| Producer inputs: {producer.get_inputs(scope)}")
log.error(
f"| Missing inputs: {missing_inputs & set([ x.name for x in producer.get_inputs(scope)])}"
)
raise InvalidInputError(scope, missing_inputs)

def add_modification_rule(
self, scopes: Union[str, List[str]], rule: ProducerRule
) -> None:
Expand Down

0 comments on commit 7fe03a8

Please sign in to comment.