Skip to content

Commit b21afd5

Browse files
committed
Refactor Rule class to improve readability and maintainability
1 parent 7191fba commit b21afd5

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

src/fuzzylogic/classes.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -464,33 +464,38 @@ def __call__(self, values: dict[Domain, float | int], method="cog") -> np.floati
464464
"""Calculate the infered value based on different methods.
465465
Default is center of gravity (cog).
466466
"""
467-
assert len(args) == max(
468-
len(c) for c in self.conditions.keys()
469-
), "Number of values must correspond to the number of domains defined as conditions!"
470-
assert isinstance(args, dict), "Please make sure to pass in the values as a dictionary."
467+
assert isinstance(values, dict), "Please make sure to pass a dict[Domain, float|int] as values."
468+
assert len(self.conditions) > 0, "No point in having a rule with no conditions, is there?"
471469
match method:
472470
case "cog":
473-
assert (
474-
len({C.domain for C in self.conditions.values()}) == 1
475-
), "For CoG, all conditions must have the same target domain."
476-
actual_values = {f: f(args[f.domain]) for S in self.conditions.keys() for f in S}
477-
478-
weights = []
479-
for K, v in self.conditions.items():
480-
x = min((actual_values[k] for k in K if k in actual_values), default=0)
471+
# iterate over the conditions and calculate the actual values and weights contributing to cog
472+
target_weights: list[tuple[Set, Number]] = []
473+
target_domain = list(self.conditions.values())[0].domain
474+
assert target_domain is not None, "Target domain must be defined."
475+
for if_sets, then_set in self.conditions.items():
476+
actual_values: list[Number] = []
477+
assert then_set.domain == target_domain, "All target sets must be in the same Domain."
478+
for s in if_sets:
479+
assert s.domain is not None, "Domains must be defined."
480+
actual_values.append(s(values[s.domain]))
481+
x = min(actual_values, default=0)
481482
if x > 0:
482-
weights.append((v, x))
483-
484-
if not weights:
483+
target_weights.append((then_set, x))
484+
if not target_weights:
485485
return None
486-
target_domain = list(self.conditions.values())[0].domain
487-
index = sum(v.center_of_gravity * x for v, x in weights) / sum(x for v, x in weights)
486+
sum_weights = 0
487+
sum_weighted_cogs = 0
488+
for then_set, weight in target_weights:
489+
sum_weighted_cogs += then_set.center_of_gravity() * weight
490+
sum_weights += weight
491+
index = sum_weighted_cogs / sum_weights
492+
488493
return (target_domain._high - target_domain._low) / len(
489494
target_domain.range
490495
) * index + target_domain._low
491496

492-
case "centroid":
493-
raise NotImplementedError("Centroid method not implemented yet.")
497+
case "centroid": # centroid == center of mass == center of gravity for simple solids
498+
raise NotImplementedError("actually the same as 'cog' if densities are uniform.")
494499
case "bisector":
495500
raise NotImplementedError("Bisector method not implemented yet.")
496501
case "mom":
@@ -529,7 +534,7 @@ def rule_from_table(table: str, references: dict):
529534
): eval(df.iloc[x, y], references) # type: ignore
530535
for x, y in product(range(len(df.index)), range(len(df.columns)))
531536
}
532-
return Rule(D)
537+
return Rule(D) # type: ignore
533538

534539

535540
if __name__ == "__main__":

0 commit comments

Comments
 (0)