"""Constraint handling for generator expansion.
This module provides constraint evaluation for filtering generated combinations
based on mutual exclusion (_mutex_) and dependency requirements (_requires_).
Constraint Types:
_mutex_: Mutual exclusion - the items of a group cannot all appear together
_requires_: Dependencies - if item A is selected, item B must also be selected
_exclude_: Exclusion - specific combinations are removed from the results
Usage:
# Items A and B cannot appear together in the same combination
{"_or_": ["A", "B", "C", "D"], "pick": 2, "_mutex_": [["A", "B"]]}
# If A is selected, B must also be selected
{"_or_": ["A", "B", "C", "D"], "pick": 2, "_requires_": [["A", "B"]]}
# Complex constraints
{"_or_": ["A", "B", "C", "D"], "pick": 3,
"_mutex_": [["A", "C"]],
"_requires_": [["B", "D"]]}
"""
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
[docs]
def apply_mutex_constraint(
    combinations: List[List[Any]],
    mutex_groups: List[List[Any]]
) -> List[List[Any]]:
    """Drop every combination that contains all members of some mutex group.

    A mutex group ``[A, B]`` forbids A and B from appearing together in the
    same combination. A combination is kept only if it violates no group.

    Args:
        combinations: Candidate combinations to filter.
        mutex_groups: Groups of items that must not all appear together.

    Returns:
        The combinations that respect every mutex group, in their original
        order. When ``mutex_groups`` is empty, ``combinations`` itself is
        returned (not a copy).

    Examples:
        >>> combos = [["A", "B"], ["A", "C"], ["B", "C"]]
        >>> apply_mutex_constraint(combos, [["A", "B"]])
        [['A', 'C'], ['B', 'C']]
        >>> apply_mutex_constraint(combos, [["A", "B"], ["B", "C"]])
        [['A', 'C']]
    """
    if not mutex_groups:
        # Nothing to enforce: hand the input back unchanged.
        return combinations
    return [
        candidate
        for candidate in combinations
        if _satisfies_mutex(candidate, mutex_groups)
    ]
def _satisfies_mutex(
    combo: List[Any],
    mutex_groups: List[List[Any]]
) -> bool:
    """Check whether a combination respects every mutex group.

    A combination violates a group when every (normalized) item of the group
    is present in it. Groups with fewer than two distinct items are ignored:
    "cannot appear together" is meaningless for them. Without this guard, an
    empty group would be a subset of every combination and reject them all.

    Args:
        combo: A single combination to check.
        mutex_groups: Groups of items that must not all appear together.

    Returns:
        True if ``combo`` violates none of the groups.
    """
    combo_set = {_normalize_item(item) for item in combo}
    for mutex_group in mutex_groups:
        mutex_set = {_normalize_item(item) for item in mutex_group}
        if len(mutex_set) < 2:
            continue  # vacuous group; see docstring
        # All members of the group present together -> violation.
        if mutex_set <= combo_set:
            return False
    return True
[docs]
def apply_requires_constraint(
    combinations: List[List[Any]],
    requires_groups: List[List[Any]]
) -> List[List[Any]]:
    """Filter out combinations that violate dependency requirements.

    A requires group ``[A, B]`` means: if A is present, B must also be
    present. The dependency is one-directional (B alone is fine). Longer
    groups ``[A, B, C]`` mean A requires both B and C. Groups with fewer than
    two items impose no requirement.

    Args:
        combinations: Candidate combinations to filter.
        requires_groups: Requirement groups; the first item is the trigger
            and the remaining items are what it requires.

    Returns:
        The combinations that satisfy every requirement, in their original
        order. When ``requires_groups`` is empty, ``combinations`` itself is
        returned (not a copy).

    Examples:
        >>> combos = [["A", "B"], ["A", "C"], ["B", "C"]]
        >>> # ["A", "C"] is dropped: A is present but B is not.
        >>> # ["B", "C"] is kept: the rule only fires when A is present.
        >>> apply_requires_constraint(combos, [["A", "B"]])
        [['A', 'B'], ['B', 'C']]
    """
    if not requires_groups:
        return combinations
    return [
        combo
        for combo in combinations
        if _satisfies_requires(combo, requires_groups)
    ]
def _satisfies_requires(
    combo: List[Any],
    requires_groups: List[List[Any]]
) -> bool:
    """Report whether a combination honors every requirement group.

    For each group with at least two items, the first item is the trigger and
    the rest are its dependencies. A group is violated when its trigger is in
    ``combo`` but at least one dependency is not. Shorter groups are skipped.

    Args:
        combo: A single combination to check.
        requires_groups: Requirement groups ``[trigger, *dependencies]``.

    Returns:
        True if no group is violated.
    """
    present = {_normalize_item(entry) for entry in combo}

    def _violates(group: List[Any]) -> bool:
        trigger = _normalize_item(group[0])
        dependencies = {_normalize_item(entry) for entry in group[1:]}
        return trigger in present and not dependencies <= present

    # Stops at the first violating group, just like an early return.
    return not any(_violates(group) for group in requires_groups if len(group) >= 2)
[docs]
def apply_exclude_constraint(
    combinations: List[List[Any]],
    exclude_combos: List[List[Any]]
) -> List[List[Any]]:
    """Remove specific combinations from the results.

    Matching compares the *sets* of normalized items. It therefore ignores
    order and duplicates: excluding ``["B", "A"]`` also removes ``["A", "B"]``.

    Args:
        combinations: Candidate combinations to filter.
        exclude_combos: Specific combinations to exclude.

    Returns:
        The combinations that match none of ``exclude_combos``, in their
        original order. When ``exclude_combos`` is empty, ``combinations``
        itself is returned (not a copy).

    Examples:
        >>> combos = [["A", "B"], ["A", "C"], ["B", "C"]]
        >>> apply_exclude_constraint(combos, [["A", "B"]])
        [['A', 'C'], ['B', 'C']]
    """
    if not exclude_combos:
        return combinations
    # Build the exclusion set once: O(1) membership per combination instead
    # of scanning a list of patterns for every candidate.
    excluded = {
        frozenset(_normalize_item(item) for item in exc)
        for exc in exclude_combos
    }
    return [
        combo
        for combo in combinations
        if frozenset(_normalize_item(item) for item in combo) not in excluded
    ]
[docs]
def apply_all_constraints(
    combinations: List[List[Any]],
    mutex_groups: Optional[List[List[Any]]] = None,
    requires_groups: Optional[List[List[Any]]] = None,
    exclude_combos: Optional[List[List[Any]]] = None
) -> List[List[Any]]:
    """Run the mutex, requires and exclude filters, in that order.

    Each filter is skipped when its specification is falsy (``None`` or
    empty). If every specification is falsy, ``combinations`` is returned
    as-is.

    Args:
        combinations: Candidate combinations to filter.
        mutex_groups: Mutual exclusion groups.
        requires_groups: Dependency requirement groups.
        exclude_combos: Specific combinations to exclude.

    Returns:
        The combinations that pass every active filter.
    """
    pipeline = (
        (apply_mutex_constraint, mutex_groups),
        (apply_requires_constraint, requires_groups),
        (apply_exclude_constraint, exclude_combos),
    )
    filtered = combinations
    for stage, spec in pipeline:
        if spec:
            filtered = stage(filtered, spec)
    return filtered
def count_with_constraints(
    n: int,
    k: int,
    mutex_groups: Optional[List[List[Any]]] = None,
    requires_groups: Optional[List[List[Any]]] = None
) -> int:
    """Return an upper bound on the number of combinations left after filtering.

    Currently the result is always ``C(n, k)``, the unconstrained count.
    ``mutex_groups`` and ``requires_groups`` are accepted for API
    compatibility but do not affect the result. For an exact count, generate
    the combinations and filter them.

    Args:
        n: Total number of items.
        k: Size of combinations to select.
        mutex_groups: Mutual exclusion groups (currently unused).
        requires_groups: Dependency requirement groups (currently unused).

    Returns:
        ``math.comb(n, k)``, which is 0 when ``k > n``.

    Raises:
        ValueError: If ``n`` or ``k`` is negative (raised by ``math.comb``).
    """
    from math import comb

    # TODO: tighten this bound with inclusion-exclusion over the constraint
    # groups. That requires knowing which of the n items each constraint
    # refers to, which this signature does not provide.
    return comb(n, k)
def _normalize_item(item: Any) -> Any:
    """Normalize an item into a hashable value for set-based comparison.

    - Dicts become frozensets of ``(key, normalized_value)`` pairs.
    - Lists and tuples become tuples of normalized elements.
    - Sets become frozensets of normalized elements.
    - Anything else is returned unchanged.

    Tuples are recursed into so that a tuple holding a dict or list (e.g.
    ``({"a": 1},)``) becomes hashable. A tuple of already-hashable values
    normalizes to an equal tuple, so existing comparisons are unaffected.

    Args:
        item: Item to normalize.

    Returns:
        The normalized item. It is hashable as long as the leaf values (and
        dict keys) are hashable.
    """
    if isinstance(item, dict):
        # Dict keys are hashable already; only the values need normalizing.
        return frozenset((k, _normalize_item(v)) for k, v in item.items())
    if isinstance(item, (list, tuple)):
        return tuple(_normalize_item(i) for i in item)
    if isinstance(item, set):
        return frozenset(_normalize_item(i) for i in item)
    return item
[docs]
def parse_constraints(node: Dict[str, Any]) -> Dict[str, List[List[Any]]]:
    """Pull the constraint keywords out of a generator node.

    Args:
        node: Node that may contain ``_mutex_``, ``_requires_`` and
            ``_exclude_`` keys.

    Returns:
        A dict with keys ``'mutex'``, ``'requires'`` and ``'exclude'``, in
        that order. Each value is the node's own value for the matching
        keyword (not a copy), or a fresh empty list when the keyword is
        absent.
    """
    keyword_table = (
        ('mutex', '_mutex_'),
        ('requires', '_requires_'),
        ('exclude', '_exclude_'),
    )
    return {name: node.get(keyword, []) for name, keyword in keyword_table}
[docs]
def validate_constraints(
    constraints: Dict[str, List[List[Any]]],
    choices: List[Any]
) -> List[str]:
    """Validate constraint specifications against the available choices.

    Checks, for each of ``'mutex'``, ``'requires'`` and ``'exclude'``:

    - The value is a list of groups. ``None`` counts as "no constraints".
    - Every group is a list.
    - Mutex and requires groups have at least 2 items. A smaller mutex group
      cannot express "appear together", and an empty one would match every
      combination.
    - Every item in a group is one of ``choices``, compared after
      normalization.

    Args:
        constraints: Constraint dict as produced by ``parse_constraints``.
        choices: Available choice items.

    Returns:
        Validation error messages, grouped as mutex, then requires, then
        exclude. Empty if everything is valid.
    """
    normalized_choices = {_normalize_item(c) for c in choices}
    errors: List[str] = []
    errors += _validate_constraint_groups(
        constraints.get('mutex', []), '_mutex_', normalized_choices, min_items=2)
    errors += _validate_constraint_groups(
        constraints.get('requires', []), '_requires_', normalized_choices, min_items=2)
    errors += _validate_constraint_groups(
        constraints.get('exclude', []), '_exclude_', normalized_choices)
    return errors


def _validate_constraint_groups(
    groups: Any,
    label: str,
    normalized_choices: Set[Any],
    min_items: int = 0
) -> List[str]:
    """Return error messages for one constraint kind (e.g. ``_mutex_``)."""
    if groups is None:
        return []  # explicit null: treat as no constraints of this kind
    if not isinstance(groups, (list, tuple)):
        return [f"{label} must be a list"]
    errors: List[str] = []
    for i, group in enumerate(groups):
        if not isinstance(group, list):
            errors.append(f"{label}[{i}] must be a list")
            continue
        if len(group) < min_items:
            errors.append(f"{label}[{i}] must have at least {min_items} items")
            continue
        for j, item in enumerate(group):
            if _normalize_item(item) not in normalized_choices:
                errors.append(f"{label}[{i}][{j}]: '{item}' not in choices")
    return errors