Skip to content

Commit 8457e50

Browse files
authored
[mypyc] Support top level function ops via CallC (#8902)
Relates to mypyc/mypyc#709 This PR supports top-level function ops via recently added CallC IR. To demonstrate the idea, it transform to_list op from PrimitiveOp to CallC. It also refines CallC with arguments coercing and support of steals.
1 parent 273a865 commit 8457e50

File tree

7 files changed

+119
-26
lines changed

7 files changed

+119
-26
lines changed

mypyc/ir/ops.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,13 +1145,19 @@ class CallC(RegisterOp):
11451145
A call to a C function
11461146
"""
11471147

1148-
error_kind = ERR_MAGIC
1149-
1150-
def __init__(self, function_name: str, args: List[Value], ret_type: RType, line: int) -> None:
1148+
def __init__(self,
1149+
function_name: str,
1150+
args: List[Value],
1151+
ret_type: RType,
1152+
steals: StealsDescription,
1153+
error_kind: int,
1154+
line: int) -> None:
1155+
self.error_kind = error_kind
11511156
super().__init__(line)
11521157
self.function_name = function_name
11531158
self.args = args
11541159
self.type = ret_type
1160+
self.steals = steals
11551161

11561162
def to_str(self, env: Environment) -> str:
11571163
args_str = ', '.join(env.format('%r', arg) for arg in self.args)
@@ -1160,6 +1166,13 @@ def to_str(self, env: Environment) -> str:
11601166
def sources(self) -> List[Value]:
11611167
return self.args
11621168

1169+
def stolen(self) -> List[Value]:
1170+
if isinstance(self.steals, list):
1171+
assert len(self.steals) == len(self.args)
1172+
return [arg for arg, steal in zip(self.args, self.steals) if steal]
1173+
else:
1174+
return [] if not self.steals else self.sources()
1175+
11631176
def accept(self, visitor: 'OpVisitor[T]') -> T:
11641177
return visitor.visit_call_c(self)
11651178

mypyc/irbuild/builder.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
)
4444
from mypyc.ir.func_ir import FuncIR, INVALID_FUNC_DEF
4545
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
46-
from mypyc.primitives.registry import func_ops
46+
from mypyc.primitives.registry import func_ops, CFunctionDescription, c_function_ops
4747
from mypyc.primitives.list_ops import list_len_op, to_list, list_pop_last
4848
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
4949
from mypyc.primitives.generic_ops import py_setattr_op, iter_op, next_op
@@ -229,6 +229,9 @@ def gen_method_call(self,
229229
def load_module(self, name: str) -> Value:
230230
return self.builder.load_module(name)
231231

232+
def call_c(self, desc: CFunctionDescription, args: List[Value], line: int) -> Value:
233+
return self.builder.call_c(desc, args, line)
234+
232235
@property
233236
def environment(self) -> Environment:
234237
return self.builder.environment
@@ -498,7 +501,7 @@ def process_iterator_tuple_assignment(self,
498501
# Assign the starred value and all values after it
499502
if target.star_idx is not None:
500503
post_star_vals = target.items[split_idx + 1:]
501-
iter_list = self.primitive_op(to_list, [iterator], line)
504+
iter_list = self.call_c(to_list, [iterator], line)
502505
iter_list_len = self.primitive_op(list_len_op, [iter_list], line)
503506
post_star_len = self.add(LoadInt(len(post_star_vals)))
504507
condition = self.binary_op(post_star_len, iter_list_len, '<=', line)
@@ -715,6 +718,11 @@ def call_refexpr_with_args(
715718

716719
# Handle data-driven special-cased primitive call ops.
717720
if callee.fullname is not None and expr.arg_kinds == [ARG_POS] * len(arg_values):
721+
call_c_ops_candidates = c_function_ops.get(callee.fullname, [])
722+
target = self.builder.matching_call_c(call_c_ops_candidates, arg_values,
723+
expr.line, self.node_type(expr))
724+
if target:
725+
return target
718726
ops = func_ops.get(callee.fullname, [])
719727
target = self.builder.matching_primitive_op(
720728
ops, arg_values, expr.line, self.node_type(expr)

mypyc/irbuild/ll_builder.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from mypyc.ir.rtypes import (
2727
RType, RUnion, RInstance, optional_value_type, int_rprimitive, float_rprimitive,
2828
bool_rprimitive, list_rprimitive, str_rprimitive, is_none_rprimitive, object_rprimitive,
29-
void_rtype
3029
)
3130
from mypyc.ir.func_ir import FuncDecl, FuncSignature
3231
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
@@ -35,7 +34,7 @@
3534
)
3635
from mypyc.primitives.registry import (
3736
binary_ops, unary_ops, method_ops, func_ops,
38-
c_method_call_ops, CFunctionDescription
37+
c_method_call_ops, CFunctionDescription, c_function_ops
3938
)
4039
from mypyc.primitives.list_ops import (
4140
list_extend_op, list_len_op, new_list_op
@@ -592,6 +591,10 @@ def builtin_call(self,
592591
args: List[Value],
593592
fn_op: str,
594593
line: int) -> Value:
594+
call_c_ops_candidates = c_function_ops.get(fn_op, [])
595+
target = self.matching_call_c(call_c_ops_candidates, args, line)
596+
if target:
597+
return target
595598
ops = func_ops.get(fn_op, [])
596599
target = self.matching_primitive_op(ops, args, line)
597600
assert target, 'Unsupported builtin function: %s' % fn_op
@@ -667,13 +670,25 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) ->
667670
self.add(Branch(value, true, false, Branch.BOOL_EXPR))
668671

669672
def call_c(self,
670-
function_name: str,
673+
desc: CFunctionDescription,
671674
args: List[Value],
672675
line: int,
673-
result_type: Optional[RType]) -> Value:
676+
result_type: Optional[RType] = None) -> Value:
674677
# handle void function via singleton RVoid instance
675-
ret_type = void_rtype if result_type is None else result_type
676-
target = self.add(CallC(function_name, args, ret_type, line))
678+
coerced = []
679+
for i, arg in enumerate(args):
680+
formal_type = desc.arg_types[i]
681+
arg = self.coerce(arg, formal_type, line)
682+
coerced.append(arg)
683+
target = self.add(CallC(desc.c_function_name, coerced, desc.return_type, desc.steals,
684+
desc.error_kind, line))
685+
if result_type and not is_runtime_subtype(target.type, result_type):
686+
if is_none_rprimitive(result_type):
687+
# Special case None return. The actual result may actually be a bool
688+
# and so we can't just coerce it.
689+
target = self.none()
690+
else:
691+
target = self.coerce(target, result_type, line)
677692
return target
678693

679694
def matching_call_c(self,
@@ -697,7 +712,7 @@ def matching_call_c(self,
697712
else:
698713
matching = desc
699714
if matching:
700-
target = self.call_c(matching.c_function_name, args, line, result_type)
715+
target = self.call_c(matching, args, line, result_type)
701716
return target
702717
return None
703718

@@ -786,8 +801,8 @@ def translate_special_method_call(self,
786801
"""
787802
ops = method_ops.get(name, [])
788803
call_c_ops_candidates = c_method_call_ops.get(name, [])
789-
call_c_op = self.matching_call_c(call_c_ops_candidates, [base_reg] + args, line,
790-
result_type=result_type)
804+
call_c_op = self.matching_call_c(call_c_ops_candidates, [base_reg] + args,
805+
line, result_type)
791806
if call_c_op is not None:
792807
return call_c_op
793808
return self.matching_primitive_op(ops, [base_reg] + args, line, result_type=result_type)

mypyc/primitives/list_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
)
99
from mypyc.primitives.registry import (
1010
name_ref_op, binary_op, func_op, method_op, custom_op, name_emit,
11-
call_emit, call_negative_bool_emit,
11+
call_emit, call_negative_bool_emit, c_function_op
1212
)
1313

1414

@@ -20,12 +20,13 @@
2020
is_borrowed=True)
2121

2222
# list(obj)
23-
to_list = func_op(
23+
to_list = c_function_op(
2424
name='builtins.list',
2525
arg_types=[object_rprimitive],
26-
result_type=list_rprimitive,
26+
return_type=list_rprimitive,
27+
c_function_name='PySequence_List',
2728
error_kind=ERR_MAGIC,
28-
emit=call_emit('PySequence_List'))
29+
)
2930

3031

3132
def emit_new(emitter: EmitterInterface, args: List[str], dest: str) -> None:
@@ -83,7 +84,6 @@ def emit_new(emitter: EmitterInterface, args: List[str], dest: str) -> None:
8384
error_kind=ERR_FALSE,
8485
emit=call_emit('CPyList_SetItem'))
8586

86-
8787
# list.append(obj)
8888
list_append_op = method_op(
8989
name='append',

mypyc/primitives/registry.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,10 @@
4545
CFunctionDescription = NamedTuple(
4646
'CFunctionDescription', [('name', str),
4747
('arg_types', List[RType]),
48-
('result_type', Optional[RType]),
48+
('return_type', RType),
4949
('c_function_name', str),
5050
('error_kind', int),
51+
('steals', StealsDescription),
5152
('priority', int)])
5253

5354
# Primitive binary ops (key is operator such as '+')
@@ -65,8 +66,12 @@
6566
# Primitive ops for reading module attributes (key is name such as 'builtins.None')
6667
name_ref_ops = {} # type: Dict[str, OpDescription]
6768

69+
# CallC op for method call(such as 'str.join')
6870
c_method_call_ops = {} # type: Dict[str, List[CFunctionDescription]]
6971

72+
# CallC op for top level function call(such as 'builtins.list')
73+
c_function_ops = {} # type: Dict[str, List[CFunctionDescription]]
74+
7075

7176
def simple_emit(template: str) -> EmitCallback:
7277
"""Construct a simple PrimitiveOp emit callback function.
@@ -323,14 +328,30 @@ def custom_op(arg_types: List[RType],
323328

324329
def c_method_op(name: str,
325330
arg_types: List[RType],
326-
result_type: Optional[RType],
331+
return_type: RType,
327332
c_function_name: str,
328333
error_kind: int,
329-
priority: int = 1) -> None:
334+
steals: StealsDescription = False,
335+
priority: int = 1) -> CFunctionDescription:
330336
ops = c_method_call_ops.setdefault(name, [])
331-
desc = CFunctionDescription(name, arg_types, result_type,
332-
c_function_name, error_kind, priority)
337+
desc = CFunctionDescription(name, arg_types, return_type,
338+
c_function_name, error_kind, steals, priority)
339+
ops.append(desc)
340+
return desc
341+
342+
343+
def c_function_op(name: str,
344+
arg_types: List[RType],
345+
return_type: RType,
346+
c_function_name: str,
347+
error_kind: int,
348+
steals: StealsDescription = False,
349+
priority: int = 1) -> CFunctionDescription:
350+
ops = c_function_ops.setdefault(name, [])
351+
desc = CFunctionDescription(name, arg_types, return_type,
352+
c_function_name, error_kind, steals, priority)
333353
ops.append(desc)
354+
return desc
334355

335356

336357
# Import various modules that set up global state.

mypyc/primitives/str_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
c_method_op(
3838
name='join',
3939
arg_types=[str_rprimitive, object_rprimitive],
40-
result_type=str_rprimitive,
40+
return_type=str_rprimitive,
4141
c_function_name='PyUnicode_Join',
4242
error_kind=ERR_MAGIC
4343
)

mypyc/test-data/irbuild-basic.test

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3383,7 +3383,7 @@ L0:
33833383
r5 = None
33843384
return r5
33853385

3386-
[case testCallCWithStrJoin]
3386+
[case testCallCWithStrJoinMethod]
33873387
from typing import List
33883388
def f(x: str, y: List[str]) -> str:
33893389
return x.join(y)
@@ -3395,3 +3395,39 @@ def f(x, y):
33953395
L0:
33963396
r0 = PyUnicode_Join(x, y)
33973397
return r0
3398+
3399+
[case testCallCWithToListFunction]
3400+
from typing import List, Iterable, Tuple, Dict
3401+
# generic object
3402+
def f(x: Iterable[int]) -> List[int]:
3403+
return list(x)
3404+
3405+
# need coercing
3406+
def g(x: Tuple[int, int, int]) -> List[int]:
3407+
return list(x)
3408+
3409+
# non-list object
3410+
def h(x: Dict[int, str]) -> List[int]:
3411+
return list(x)
3412+
3413+
[out]
3414+
def f(x):
3415+
x :: object
3416+
r0 :: list
3417+
L0:
3418+
r0 = PySequence_List(x)
3419+
return r0
3420+
def g(x):
3421+
x :: tuple[int, int, int]
3422+
r0 :: object
3423+
r1 :: list
3424+
L0:
3425+
r0 = box(tuple[int, int, int], x)
3426+
r1 = PySequence_List(r0)
3427+
return r1
3428+
def h(x):
3429+
x :: dict
3430+
r0 :: list
3431+
L0:
3432+
r0 = PySequence_List(x)
3433+
return r0

0 commit comments

Comments
 (0)