Skip to content

Commit 858b497

Browse files
committed
chore: fixed typing
1 parent 2b1d79c commit 858b497

File tree

2 files changed

+141
-110
lines changed

2 files changed

+141
-110
lines changed

src/fuzzylogic/classes.py

Lines changed: 94 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def plt(*args, **kwargs):
2222
from .combinators import MAX, MIN, bounded_sum, product, simple_disjoint_sum
2323
from .functions import inv, normalize
2424

25+
type Number = int | float
26+
2527

2628
class FuzzyWarning(UserWarning):
2729
"""Extra Exception so that user code can filter exceptions specific to this lib."""
@@ -49,7 +51,7 @@ class Domain:
4951
It is possible now to call derived sets without assignment first!
5052
>>> from .hedges import very
5153
>>> (very(~temp.hot) | ~very(temp.hot))(2)
52-
1
54+
1.0
5355
5456
You MUST NOT add arbitrary attributes to an *instance* of Domain - you can
5557
however subclass or modify the class itself. If you REALLY have to add attributes,
@@ -65,14 +67,16 @@ class Domain:
6567
def __init__(
6668
self,
6769
name: str,
68-
low: float | int,
69-
high: float | int,
70-
res: float | int = 1,
70+
low: Number,
71+
high: Number,
72+
res: Number = 1,
7173
sets: dict | None = None,
7274
) -> None:
7375
"""Define a domain."""
7476
assert low < high, "higher bound must be greater than lower."
7577
assert res > 0, "resolution can't be negative or zero"
78+
assert isinstance(name, str), "Name must be a string."
79+
assert str.isidentifier(name), "Name must be a valid identifier."
7680
self._name = name
7781
self._high = high
7882
self._low = low
@@ -85,6 +89,10 @@ def __call__(self, x):
8589
raise FuzzyWarning(f"{x} is outside of domain!")
8690
return {name: s.func(x) for name, s in self._sets.items()}
8791

92+
def __len__(self):
93+
"""Return the size of the domain, as the actual number of possible values, calculated internally."""
94+
return len(self.range)
95+
8896
def __str__(self):
8997
"""Return a string to print()."""
9098
return self._name
@@ -95,15 +103,13 @@ def __repr__(self):
95103

96104
def __eq__(self, other):
97105
"""Test equality of two domains."""
98-
return all(
99-
[
100-
self._name == other._name,
101-
self._low == other._low,
102-
self._high == other._high,
103-
self._res == other._res,
104-
self._sets == other._sets,
105-
]
106-
)
106+
return all([
107+
self._name == other._name,
108+
self._low == other._low,
109+
self._high == other._high,
110+
self._res == other._res,
111+
self._sets == other._sets,
112+
])
107113

108114
def __hash__(self):
109115
return id(self)
@@ -188,57 +194,71 @@ class Set:
188194
189195
"""
190196

197+
type T = Set
191198
name = None # these are set on assignment to the domain! DO NOT MODIFY
192199
domain = None
193200

194-
def __init__(self, func: Callable, *, name: str | None = None, domain: Domain | None = None):
195-
self.func = func
196-
self.domain = domain
197-
self.name = name
198-
self.__center_of_gravity = None
199-
200-
def __call__(self, x):
201-
return self.func(x)
201+
def __init__(
202+
self,
203+
func: Callable[..., Number],
204+
*,
205+
name: str | None = None,
206+
domain: Domain | None = None,
207+
):
208+
self.func: Callable[..., Number] = func
209+
self.domain: Domain | None = domain
210+
self.name: str | None = name
211+
self.__center_of_gravity: np.floating | None = None
212+
213+
def __call__(self, x: Number | np.ndarray) -> Number | np.ndarray:
214+
if isinstance(x, np.ndarray):
215+
return np.array([self.func(v) for v in x])
216+
else:
217+
return self.func(x)
202218

203-
def __invert__(self):
219+
def __invert__(self) -> T:
204220
"""Return a new set with 1 - function."""
205221
return Set(inv(self.func), domain=self.domain)
206222

207-
def __neg__(self):
223+
def __neg__(self) -> T:
208224
"""Synonyme for invert."""
209225
return Set(inv(self.func), domain=self.domain)
210226

211-
def __and__(self, other):
227+
def __and__(self, other: T) -> T:
212228
"""Return a new set with modified function."""
213229
assert self.domain == other.domain
214230
return Set(MIN(self.func, other.func), domain=self.domain)
215231

216-
def __or__(self, other):
232+
def __or__(self, other: T) -> T:
217233
"""Return a new set with modified function."""
218234
assert self.domain == other.domain
219235
return Set(MAX(self.func, other.func), domain=self.domain)
220236

221-
def __mul__(self, other):
237+
def __mul__(self, other: T) -> T:
222238
"""Return a new set with modified function."""
223239
assert self.domain == other.domain
224240
return Set(product(self.func, other.func), domain=self.domain)
225241

226-
def __add__(self, other):
242+
def __add__(self, other: T) -> T:
227243
"""Return a new set with modified function."""
228244
assert self.domain == other.domain
229245
return Set(bounded_sum(self.func, other.func), domain=self.domain)
230246

231-
def __xor__(self, other):
247+
def __xor__(self, other: T) -> T:
232248
"""Return a new set with modified function."""
233249
assert self.domain == other.domain
234250
return Set(simple_disjoint_sum(self.func, other.func), domain=self.domain)
235251

236-
def __pow__(self, power):
252+
def __pow__(self, power: int) -> T:
237253
"""Return a new set with modified function."""
254+
238255
# FYI: pow is used with hedges
239-
return Set(lambda x: pow(self.func(x), power), domain=self.domain)
256+
def f(x: float):
257+
return pow(self.func(x), power) # TODO: test this
240258

241-
def __eq__(self, other):
259+
return Set(f, domain=self.domain)
260+
261+
def __eq__(self, other: T) -> bool:
242262
"""A set is equal with another if both return the same values over the same range."""
243263
if self.domain is None or other.domain is None:
244264
# It would require complete AST analysis to check whether both Sets
@@ -251,49 +271,49 @@ def __eq__(self, other):
251271
# we simply can check if they map to the same values
252272
return np.array_equal(self.array(), other.array())
253273

254-
def __le__(self, other):
274+
def __le__(self, other: T) -> bool:
255275
"""If this <= other, it means this is a subset of the other."""
256276
assert self.domain == other.domain
257277
if self.domain is None or other.domain is None:
258278
raise FuzzyWarning("Can't compare without Domains.")
259279
return all(np.less_equal(self.array(), other.array()))
260280

261-
def __lt__(self, other):
281+
def __lt__(self, other: T) -> bool:
262282
"""If this < other, it means this is a proper subset of the other."""
263283
assert self.domain == other.domain
264284
if self.domain is None or other.domain is None:
265285
raise FuzzyWarning("Can't compare without Domains.")
266286
return all(np.less(self.array(), other.array()))
267287

268-
def __ge__(self, other):
288+
def __ge__(self, other: T) -> bool:
269289
"""If this >= other, it means this is a superset of the other."""
270290
assert self.domain == other.domain
271291
if self.domain is None or other.domain is None:
272292
raise FuzzyWarning("Can't compare without Domains.")
273293
return all(np.greater_equal(self.array(), other.array()))
274294

275-
def __gt__(self, other):
295+
def __gt__(self, other: T) -> bool:
276296
"""If this > other, it means this is a proper superset of the other."""
277297
assert self.domain == other.domain
278298
if self.domain is None or other.domain is None:
279299
raise FuzzyWarning("Can't compare without Domains.")
280300
return all(np.greater(self.array(), other.array()))
281301

282-
def __len__(self):
302+
def __len__(self) -> int:
283303
"""Number of membership values in the set, defined by bounds and resolution of domain."""
284304
if self.domain is None:
285305
raise FuzzyWarning("No domain.")
286306
return len(self.array())
287307

288308
@property
289-
def cardinality(self):
309+
def cardinality(self) -> int:
290310
"""The sum of all values in the set."""
291311
if self.domain is None:
292312
raise FuzzyWarning("No domain.")
293313
return sum(self.array())
294314

295315
@property
296-
def relative_cardinality(self):
316+
def relative_cardinality(self) -> np.floating | float:
297317
"""Relative cardinality is the sum of all membership values by number of all values."""
298318
if self.domain is None:
299319
raise FuzzyWarning("No domain.")
@@ -302,7 +322,7 @@ def relative_cardinality(self):
302322
raise FuzzyWarning("The domain has no element.")
303323
return self.cardinality / len(self)
304324

305-
def concentrated(self):
325+
def concentrated(self) -> T:
306326
"""
307327
Alternative to hedge "very".
308328
@@ -311,7 +331,7 @@ def concentrated(self):
311331
"""
312332
return Set(lambda x: self.func(x) ** 2, domain=self.domain)
313333

314-
def intensified(self):
334+
def intensified(self) -> T:
315335
"""
316336
Alternative to hedges.
317337
@@ -324,11 +344,11 @@ def f(x):
324344

325345
return Set(f, domain=self.domain)
326346

327-
def dilated(self):
347+
def dilated(self) -> T:
328348
"""Expand the set with more values and already included values are enhanced."""
329349
return Set(lambda x: self.func(x) ** 1.0 / 2.0, domain=self.domain)
330350

331-
def multiplied(self, n):
351+
def multiplied(self, n) -> T:
332352
"""Multiply with a constant factor, changing all membership values."""
333353
return Set(lambda x: self.func(x) * n, domain=self.domain)
334354

@@ -340,15 +360,22 @@ def plot(self):
340360
V = [self.func(x) for x in R]
341361
plt.plot(R, V)
342362

343-
def array(self):
363+
def array(self) -> np.ndarray:
344364
"""Return an array of all values for this set within the given domain."""
345365
if self.domain is None:
346366
raise FuzzyWarning("No domain assigned.")
347367
return np.fromiter((self.func(x) for x in self.domain.range), float)
348368

349-
def center_of_gravity(self):
350-
"""Return the center of gravity for this distribution, within the given domain."""
369+
def range(self) -> np.ndarray:
370+
"""Return the range of the domain."""
371+
if self.domain is None:
372+
raise FuzzyWarning("No domain assigned.")
373+
return self.domain.range
351374

375+
def center_of_gravity(self) -> np.floating | float:
376+
"""Return the center of gravity for this distribution, within the given domain."""
377+
if self.__center_of_gravity is not None:
378+
return self.__center_of_gravity
352379
assert self.domain is not None, "No center of gravity with no domain."
353380
weights = self.array()
354381
if sum(weights) == 0:
@@ -357,7 +384,7 @@ def center_of_gravity(self):
357384
self.__center_of_gravity = cog
358385
return cog
359386

360-
def __repr__(self):
387+
def __repr__(self) -> str:
361388
"""
362389
Return a string representation of the Set that reconstructs the set with eval().
363390
@@ -377,9 +404,11 @@ def create_function_closure():
377404
func = types.FunctionType(*args[:-1] + [closure])
378405
return func
379406
"""
380-
return f"Set({self.func})"
407+
if self.domain is not None:
408+
return f"{self.domain._name}."
409+
return f"Set({__name__}({self.func.__qualname__})"
381410

382-
def __str__(self):
411+
def __str__(self) -> str:
383412
"""Return a string for print()."""
384413
if self.domain is not None:
385414
return f"{self.domain._name}.{self.name}"
@@ -388,48 +417,50 @@ def __str__(self):
388417
else:
389418
return f"dangling Set({self.name}"
390419

391-
def normalized(self):
420+
def normalized(self) -> T:
392421
"""Return a set that is normalized *for this domain* with 1 as max."""
393422
if self.domain is None:
394423
raise FuzzyWarning("Can't normalize without domain.")
395424
return Set(normalize(max(self.array()), self.func), domain=self.domain)
396425

397-
def __hash__(self):
426+
def __hash__(self) -> int:
398427
return id(self)
399428

400429

401430
class Rule:
402431
"""
403-
A collection of bound sets that span a multi-dimensional space of their respective domains.
432+
Collection of bound sets spanning a multi-dimensional space of their domains, mapping to a target domain.
433+
404434
"""
405435

406-
def __init__(self, conditions, func=None):
407-
print("ohalala")
408-
self.conditions = {frozenset(C): oth for C, oth in conditions.items()}
409-
self.func = func
436+
type T = Rule
437+
438+
def __init__(self, conditions_in: dict[Iterable[Set] | Set, Set]):
439+
self.conditions: dict[frozenset[Set], Set] = {}
440+
for if_sets, then_set in conditions_in.items():
441+
if isinstance(if_sets, Set):
442+
if_sets = (if_sets,)
443+
self.conditions[frozenset(if_sets)] = then_set
410444

411-
def __add__(self, other):
412-
assert isinstance(other, Rule)
445+
def __add__(self, other: T):
413446
return Rule({**self.conditions, **other.conditions})
414447

415-
def __radd__(self, other):
416-
assert isinstance(other, (Rule, int))
448+
def __radd__(self, other: T | int) -> T:
417449
# we're using sum(..)
418450
if isinstance(other, int):
419451
return self
420452
return Rule({**self.conditions, **other.conditions})
421453

422-
def __or__(self, other):
423-
assert isinstance(other, Rule)
454+
def __or__(self, other: T):
424455
return Rule({**self.conditions, **other.conditions})
425456

426-
def __eq__(self, other):
457+
def __eq__(self, other: T):
427458
return self.conditions == other.conditions
428459

429460
def __getitem__(self, key):
430461
return self.conditions[frozenset(key)]
431462

432-
def __call__(self, args: "dict[Domain, float]", method="cog"):
463+
def __call__(self, values: dict[Domain, float | int], method="cog") -> np.floating | float | None:
433464
"""Calculate the infered value based on different methods.
434465
Default is center of gravity (cog).
435466
"""

0 commit comments

Comments
 (0)