@@ -572,7 +572,13 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
572
572
"""
573
573
ws = ""
574
574
# original, if small, if large
575
- conversion_data = (
575
+ conversion_data : tuple [
576
+ tuple [type , type , type ],
577
+ tuple [type , type , type ],
578
+ tuple [type , type , type ],
579
+ tuple [type , type , type ],
580
+ tuple [type , type , type ],
581
+ ] = (
576
582
(np .bool_ , np .int8 , np .int8 ),
577
583
(np .uint8 , np .int8 , np .int16 ),
578
584
(np .uint16 , np .int16 , np .int32 ),
@@ -600,13 +606,12 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
600
606
dtype = data [col ].dtype
601
607
for c_data in conversion_data :
602
608
if dtype == c_data [0 ]:
603
- # Value of type variable "_IntType" of "iinfo" cannot be "object"
604
- if data [col ].max () <= np .iinfo (c_data [1 ]).max : # type: ignore[type-var]
609
+ if data [col ].max () <= np .iinfo (c_data [1 ]).max :
605
610
dtype = c_data [1 ]
606
611
else :
607
612
dtype = c_data [2 ]
608
613
if c_data [2 ] == np .int64 : # Warn if necessary
609
- if data [col ].max () >= 2 ** 53 :
614
+ if data [col ].max () >= 2 ** 53 :
610
615
ws = precision_loss_doc .format ("uint64" , "float64" )
611
616
612
617
data [col ] = data [col ].astype (dtype )
@@ -623,7 +628,7 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
623
628
data [col ] = data [col ].astype (np .int32 )
624
629
else :
625
630
data [col ] = data [col ].astype (np .float64 )
626
- if data [col ].max () >= 2 ** 53 or data [col ].min () <= - (2 ** 53 ):
631
+ if data [col ].max () >= 2 ** 53 or data [col ].min () <= - (2 ** 53 ):
627
632
ws = precision_loss_doc .format ("int64" , "float64" )
628
633
elif dtype in (np .float32 , np .float64 ):
629
634
if np .isinf (data [col ]).any ():
@@ -967,17 +972,15 @@ def __init__(self) -> None:
967
972
(255 , np .dtype (np .float64 )),
968
973
]
969
974
)
970
- self .DTYPE_MAP_XML = {
975
+ self .DTYPE_MAP_XML : dict [ int , np . dtype ] = {
971
976
32768 : np .dtype (np .uint8 ), # Keys to GSO
972
977
65526 : np .dtype (np .float64 ),
973
978
65527 : np .dtype (np .float32 ),
974
979
65528 : np .dtype (np .int32 ),
975
980
65529 : np .dtype (np .int16 ),
976
981
65530 : np .dtype (np .int8 ),
977
982
}
978
- # error: Argument 1 to "list" has incompatible type "str";
979
- # expected "Iterable[int]" [arg-type]
980
- self .TYPE_MAP = list (range (251 )) + list ("bhlfd" ) # type: ignore[arg-type]
983
+ self .TYPE_MAP = list (tuple (range (251 )) + tuple ("bhlfd" ))
981
984
self .TYPE_MAP_XML = {
982
985
# Not really a Q, unclear how to handle byteswap
983
986
32768 : "Q" ,
@@ -1296,9 +1299,7 @@ def g(typ: int) -> str | np.dtype:
1296
1299
if typ <= 2045 :
1297
1300
return str (typ )
1298
1301
try :
1299
- # error: Incompatible return value type (got "Type[number]", expected
1300
- # "Union[str, dtype]")
1301
- return self .DTYPE_MAP_XML [typ ] # type: ignore[return-value]
1302
+ return self .DTYPE_MAP_XML [typ ]
1302
1303
except KeyError as err :
1303
1304
raise ValueError (f"cannot convert stata dtype [{ typ } ]" ) from err
1304
1305
0 commit comments