@@ -464,33 +464,38 @@ def __call__(self, values: dict[Domain, float | int], method="cog") -> np.floati
464
464
"""Calculate the infered value based on different methods.
465
465
Default is center of gravity (cog).
466
466
"""
467
- assert len (args ) == max (
468
- len (c ) for c in self .conditions .keys ()
469
- ), "Number of values must correspond to the number of domains defined as conditions!"
470
- assert isinstance (args , dict ), "Please make sure to pass in the values as a dictionary."
467
+ assert isinstance (values , dict ), "Please make sure to pass a dict[Domain, float|int] as values."
468
+ assert len (self .conditions ) > 0 , "No point in having a rule with no conditions, is there?"
471
469
match method :
472
470
case "cog" :
473
- assert (
474
- len ({C .domain for C in self .conditions .values ()}) == 1
475
- ), "For CoG, all conditions must have the same target domain."
476
- actual_values = {f : f (args [f .domain ]) for S in self .conditions .keys () for f in S }
477
-
478
- weights = []
479
- for K , v in self .conditions .items ():
480
- x = min ((actual_values [k ] for k in K if k in actual_values ), default = 0 )
471
+ # iterate over the conditions and calculate the actual values and weights contributing to cog
472
+ target_weights : list [tuple [Set , Number ]] = []
473
+ target_domain = list (self .conditions .values ())[0 ].domain
474
+ assert target_domain is not None , "Target domain must be defined."
475
+ for if_sets , then_set in self .conditions .items ():
476
+ actual_values : list [Number ] = []
477
+ assert then_set .domain == target_domain , "All target sets must be in the same Domain."
478
+ for s in if_sets :
479
+ assert s .domain is not None , "Domains must be defined."
480
+ actual_values .append (s (values [s .domain ]))
481
+ x = min (actual_values , default = 0 )
481
482
if x > 0 :
482
- weights .append ((v , x ))
483
-
484
- if not weights :
483
+ target_weights .append ((then_set , x ))
484
+ if not target_weights :
485
485
return None
486
- target_domain = list (self .conditions .values ())[0 ].domain
487
- index = sum (v .center_of_gravity * x for v , x in weights ) / sum (x for v , x in weights )
486
+ sum_weights = 0
487
+ sum_weighted_cogs = 0
488
+ for then_set , weight in target_weights :
489
+ sum_weighted_cogs += then_set .center_of_gravity () * weight
490
+ sum_weights += weight
491
+ index = sum_weighted_cogs / sum_weights
492
+
488
493
return (target_domain ._high - target_domain ._low ) / len (
489
494
target_domain .range
490
495
) * index + target_domain ._low
491
496
492
- case "centroid" :
493
- raise NotImplementedError ("Centroid method not implemented yet ." )
497
+ case "centroid" : # centroid == center of mass == center of gravity for simple solids
498
+ raise NotImplementedError ("actually the same as 'cog' if densities are uniform ." )
494
499
case "bisector" :
495
500
raise NotImplementedError ("Bisector method not implemented yet." )
496
501
case "mom" :
@@ -529,7 +534,7 @@ def rule_from_table(table: str, references: dict):
529
534
): eval (df .iloc [x , y ], references ) # type: ignore
530
535
for x , y in product (range (len (df .index )), range (len (df .columns )))
531
536
}
532
- return Rule (D )
537
+ return Rule (D ) # type: ignore
533
538
534
539
535
540
if __name__ == "__main__" :
0 commit comments