Skip to content

Commit 7832de0

Browse files
authored
[Compile Time Constant Extraction] Add extraction of initialization calls (#62365)
1 parent c53d2d7 commit 7832de0

File tree

4 files changed

+196
-56
lines changed

4 files changed

+196
-56
lines changed

include/swift/AST/ConstTypeInfo.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
#ifndef SWIFT_AST_CONST_TYPE_INFO_H
1414
#define SWIFT_AST_CONST_TYPE_INFO_H
1515

16+
#include "swift/AST/Type.h"
17+
#include <memory>
1618
#include <string>
1719
#include <vector>
18-
#include <memory>
1920

2021
namespace swift {
2122
class NominalTypeDecl;
@@ -55,13 +56,9 @@ class RawLiteralValue : public CompileTimeValue {
5556
};
5657

5758
struct FunctionParameter {
58-
public:
59-
std::string getLabel() { return Label; }
60-
swift::Type *getType() { return Type; }
61-
62-
private:
6359
std::string Label;
64-
swift::Type *Type;
60+
swift::Type Type;
61+
std::shared_ptr<CompileTimeValue> Value;
6562
};
6663

6764
/// A representation of a call to a type's initializer
@@ -77,6 +74,7 @@ class InitCallValue : public CompileTimeValue {
7774
}
7875

7976
std::string getName() const { return Name; }
77+
std::vector<FunctionParameter> getParameters() const { return Parameters; }
8078

8179
private:
8280
std::string Name;

include/swift/ConstExtract/ConstExtract.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ gatherConstValuesForModule(const std::unordered_set<std::string> &Protocols,
5656
/// provided output stream.
5757
bool writeAsJSONToFile(const std::vector<ConstValueTypeInfo> &ConstValueInfos,
5858
llvm::raw_fd_ostream &OS);
59-
6059
} // namespace swift
6160

6261
#endif

lib/ConstExtract/ConstExtract.cpp

Lines changed: 107 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13-
#include "swift/Basic/TypeID.h"
1413
#include "swift/ConstExtract/ConstExtract.h"
15-
#include "swift/ConstExtract/ConstExtractRequests.h"
1614
#include "swift/AST/ASTContext.h"
1715
#include "swift/AST/ASTWalker.h"
1816
#include "swift/AST/Decl.h"
@@ -21,11 +19,14 @@
2119
#include "swift/AST/Evaluator.h"
2220
#include "swift/AST/SourceFile.h"
2321
#include "swift/AST/TypeCheckRequests.h"
22+
#include "swift/Basic/TypeID.h"
23+
#include "swift/ConstExtract/ConstExtractRequests.h"
24+
#include "swift/Subsystems.h"
25+
#include "llvm/ADT/PointerUnion.h"
2426
#include "llvm/ADT/StringRef.h"
2527
#include "llvm/Support/JSON.h"
2628
#include "llvm/Support/YAMLParser.h"
2729
#include "llvm/Support/YAMLTraits.h"
28-
#include "swift/Subsystems.h"
2930

3031
#include <set>
3132
#include <sstream>
@@ -119,17 +120,66 @@ parseProtocolListFromFile(StringRef protocolListFilePath,
119120
return true;
120121
}
121122

123+
static std::string extractLiteralOutput(Expr *expr) {
124+
std::string LiteralOutput;
125+
llvm::raw_string_ostream OutputStream(LiteralOutput);
126+
expr->printConstExprValue(&OutputStream, nullptr);
127+
128+
return LiteralOutput;
129+
}
130+
122131
static std::shared_ptr<CompileTimeValue>
123132
extractPropertyInitializationValue(VarDecl *propertyDecl) {
124133
auto binding = propertyDecl->getParentPatternBinding();
125134
if (binding) {
126135
auto originalInit = binding->getOriginalInit(0);
127136
if (originalInit) {
128-
std::string LiteralOutput;
129-
llvm::raw_string_ostream OutputStream(LiteralOutput);
130-
originalInit->printConstExprValue(&OutputStream, nullptr);
131-
if (!LiteralOutput.empty())
132-
return std::make_shared<RawLiteralValue>(LiteralOutput);
137+
auto literalOutput = extractLiteralOutput(originalInit);
138+
if (!literalOutput.empty()) {
139+
return std::make_shared<RawLiteralValue>(literalOutput);
140+
}
141+
142+
if (auto callExpr = dyn_cast<CallExpr>(originalInit)) {
143+
if (callExpr->getFn()->getKind() != ExprKind::ConstructorRefCall) {
144+
return std::make_shared<RuntimeValue>();
145+
}
146+
147+
std::vector<FunctionParameter> parameters;
148+
const auto args = callExpr->getArgs();
149+
for (auto arg : *args) {
150+
auto label = arg.getLabel().str().str();
151+
auto expr = arg.getExpr();
152+
153+
switch (expr->getKind()) {
154+
case ExprKind::DefaultArgument: {
155+
auto defaultArgument = cast<DefaultArgumentExpr>(expr);
156+
auto *decl = defaultArgument->getParamDecl();
157+
158+
if (decl->hasDefaultExpr()) {
159+
literalOutput =
160+
extractLiteralOutput(decl->getTypeCheckedDefaultExpr());
161+
}
162+
163+
break;
164+
}
165+
default:
166+
literalOutput = extractLiteralOutput(expr);
167+
break;
168+
}
169+
170+
if (literalOutput.empty()) {
171+
parameters.push_back(
172+
{label, expr->getType(), std::make_shared<RuntimeValue>()});
173+
} else {
174+
parameters.push_back(
175+
{label, expr->getType(),
176+
std::make_shared<RawLiteralValue>(literalOutput)});
177+
}
178+
}
179+
180+
auto name = toFullyQualifiedTypeNameString(callExpr->getType());
181+
return std::make_shared<InitCallValue>(name, parameters);
182+
}
133183
}
134184
}
135185

@@ -138,9 +188,7 @@ extractPropertyInitializationValue(VarDecl *propertyDecl) {
138188
if (node.is<Stmt *>()) {
139189
if (auto returnStmt = dyn_cast<ReturnStmt>(node.get<Stmt *>())) {
140190
auto expr = returnStmt->getResult();
141-
std::string LiteralOutput;
142-
llvm::raw_string_ostream OutputStream(LiteralOutput);
143-
expr->printConstExprValue(&OutputStream, nullptr);
191+
std::string LiteralOutput = extractLiteralOutput(expr);
144192
if (!LiteralOutput.empty())
145193
return std::make_shared<RawLiteralValue>(LiteralOutput);
146194
}
@@ -207,33 +255,49 @@ gatherConstValuesForPrimary(const std::unordered_set<std::string> &Protocols,
207255
return Result;
208256
}
209257

210-
std::string toString(const CompileTimeValue *Value) {
211-
switch (Value->getKind()) {
212-
case CompileTimeValue::RawLiteral:
213-
return cast<RawLiteralValue>(Value)->getValue();
214-
case CompileTimeValue::InitCall:
215-
// TODO
216-
case CompileTimeValue::Builder:
217-
// TODO
218-
case CompileTimeValue::Dictionary:
219-
// TODO
220-
case CompileTimeValue::Runtime:
221-
return "Unknown";
258+
void writeValue(llvm::json::OStream &JSON,
259+
std::shared_ptr<CompileTimeValue> Value) {
260+
auto value = Value.get();
261+
switch (value->getKind()) {
262+
case CompileTimeValue::ValueKind::RawLiteral: {
263+
JSON.attribute("valueKind", "RawLiteral");
264+
JSON.attribute("value", cast<RawLiteralValue>(value)->getValue());
265+
break;
222266
}
223-
}
224267

225-
std::string toString(CompileTimeValue::ValueKind Kind) {
226-
switch (Kind) {
227-
case CompileTimeValue::ValueKind::RawLiteral:
228-
return "RawLiteral";
229-
case CompileTimeValue::ValueKind::InitCall:
230-
return "InitCall";
231-
case CompileTimeValue::ValueKind::Builder:
232-
return "Builder";
233-
case CompileTimeValue::ValueKind::Dictionary:
234-
return "Dictionary";
235-
case CompileTimeValue::ValueKind::Runtime:
236-
return "Runtime";
268+
case CompileTimeValue::ValueKind::InitCall: {
269+
auto initCallValue = cast<InitCallValue>(value);
270+
271+
JSON.attribute("valueKind", "InitCall");
272+
JSON.attributeObject("value", [&]() {
273+
JSON.attribute("type", initCallValue->getName());
274+
JSON.attributeArray("arguments", [&] {
275+
for (auto FP : initCallValue->getParameters()) {
276+
JSON.object([&] {
277+
JSON.attribute("label", FP.Label);
278+
JSON.attribute("type", toFullyQualifiedTypeNameString(FP.Type));
279+
writeValue(JSON, FP.Value);
280+
});
281+
}
282+
});
283+
});
284+
break;
285+
}
286+
287+
case CompileTimeValue::ValueKind::Builder: {
288+
JSON.attribute("valueKind", "Builder");
289+
break;
290+
}
291+
292+
case CompileTimeValue::ValueKind::Dictionary: {
293+
JSON.attribute("valueKind", "Dictionary");
294+
break;
295+
}
296+
297+
case CompileTimeValue::ValueKind::Runtime: {
298+
JSON.attribute("valueKind", "Runtime");
299+
break;
300+
}
237301
}
238302
}
239303

@@ -252,19 +316,14 @@ bool writeAsJSONToFile(const std::vector<ConstValueTypeInfo> &ConstValueInfos,
252316
JSON.attributeArray("properties", [&] {
253317
for (const auto &PropertyInfo : TypeInfo.Properties) {
254318
JSON.object([&] {
255-
const auto *PropertyDecl = PropertyInfo.VarDecl;
256-
JSON.attribute("label", PropertyDecl->getName().str().str());
257-
JSON.attribute("type", toFullyQualifiedTypeNameString(PropertyDecl->getType()));
258-
JSON.attribute("isStatic",
259-
PropertyDecl->isStatic() ? "true" : "false");
319+
const auto *decl = PropertyInfo.VarDecl;
320+
JSON.attribute("label", decl->getName().str().str());
321+
JSON.attribute("type",
322+
toFullyQualifiedTypeNameString(decl->getType()));
323+
JSON.attribute("isStatic", decl->isStatic() ? "true" : "false");
260324
JSON.attribute("isComputed",
261-
!PropertyDecl->hasStorage() ? "true" : "false");
262-
auto value = PropertyInfo.Value.get();
263-
auto valueKind = value->getKind();
264-
JSON.attribute("valueKind", toString(valueKind));
265-
if (valueKind != CompileTimeValue::ValueKind::Runtime) {
266-
JSON.attribute("value", toString(value));
267-
}
325+
!decl->hasStorage() ? "true" : "false");
326+
writeValue(JSON, PropertyInfo.Value);
268327
});
269328
}
270329
});

test/ConstExtraction/fields.swift

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,71 @@
5959
// CHECK-NEXT: "value": "[(\"One\", 1), (\"Two\", 2), (\"Three\", 3)]"
6060
// CHECK-NEXT: },
6161
// CHECK-NEXT: {
62+
// CHECK-NEXT: "label": "p10",
63+
// CHECK-NEXT: "type": "fields.Bar",
64+
// CHECK-NEXT: "isStatic": "false",
65+
// CHECK-NEXT: "isComputed": "false",
66+
// CHECK-NEXT: "valueKind": "InitCall",
67+
// CHECK-NEXT: "value": {
68+
// CHECK-NEXT: "type": "fields.Bar",
69+
// CHECK-NEXT: "arguments": []
70+
// CHECK-NEXT: }
71+
// CHECK-NEXT: },
72+
// CHECK-NEXT: {
73+
// CHECK-NEXT: "label": "p11",
74+
// CHECK-NEXT: "type": "fields.Bat",
75+
// CHECK-NEXT: "isStatic": "false",
76+
// CHECK-NEXT: "isComputed": "false",
77+
// CHECK-NEXT: "valueKind": "InitCall",
78+
// CHECK-NEXT: "value": {
79+
// CHECK-NEXT: "type": "fields.Bat",
80+
// CHECK-NEXT: "arguments": [
81+
// CHECK-NEXT: {
82+
// CHECK-NEXT: "label": "buz",
83+
// CHECK-NEXT: "type": "Swift.String",
84+
// CHECK-NEXT: "valueKind": "RawLiteral",
85+
// CHECK-NEXT: "value": "\"\""
86+
// CHECK-NEXT: },
87+
// CHECK-NEXT: {
88+
// CHECK-NEXT: "label": "fuz",
89+
// CHECK-NEXT: "type": "Swift.Int",
90+
// CHECK-NEXT: "valueKind": "RawLiteral",
91+
// CHECK-NEXT: "value": "0"
92+
// CHECK-NEXT: }
93+
// CHECK-NEXT: ]
94+
// CHECK-NEXT: }
95+
// CHECK-NEXT: },
96+
// CHECK-NEXT: {
97+
// CHECK-NEXT: "label": "p12",
98+
// CHECK-NEXT: "type": "fields.Bat",
99+
// CHECK-NEXT: "isStatic": "false",
100+
// CHECK-NEXT: "isComputed": "false",
101+
// CHECK-NEXT: "valueKind": "InitCall",
102+
// CHECK-NEXT: "value": {
103+
// CHECK-NEXT: "type": "fields.Bat",
104+
// CHECK-NEXT: "arguments": [
105+
// CHECK-NEXT: {
106+
// CHECK-NEXT: "label": "buz",
107+
// CHECK-NEXT: "type": "Swift.String",
108+
// CHECK-NEXT: "valueKind": "RawLiteral",
109+
// CHECK-NEXT: "value": "\"hello\""
110+
// CHECK-NEXT: },
111+
// CHECK-NEXT: {
112+
// CHECK-NEXT: "label": "fuz",
113+
// CHECK-NEXT: "type": "Swift.Int",
114+
// CHECK-NEXT: "valueKind": "Runtime"
115+
// CHECK-NEXT: }
116+
// CHECK-NEXT: ]
117+
// CHECK-NEXT: }
118+
// CHECK-NEXT: },
119+
// CHECK-NEXT: {
120+
// CHECK-NEXT: "label": "p13",
121+
// CHECK-NEXT: "type": "Swift.Int",
122+
// CHECK-NEXT: "isStatic": "false",
123+
// CHECK-NEXT: "isComputed": "false",
124+
// CHECK-NEXT: "valueKind": "Runtime"
125+
// CHECK-NEXT: },
126+
// CHECK-NEXT: {
62127
// CHECK-NEXT: "label": "p0",
63128
// CHECK-NEXT: "type": "Swift.Int",
64129
// CHECK-NEXT: "isStatic": "true",
@@ -107,6 +172,25 @@ public struct Foo {
107172
let p7: Bool? = nil
108173
let p8: (Int, Float) = (42, 6.6)
109174
let p9: [String: Int] = ["One": 1, "Two": 2, "Three": 3]
175+
let p10 = Bar()
176+
let p11: Bat = .init()
177+
let p12 = Bat(buz: "hello", fuz: adder(2, 3))
178+
let p13: Int = adder(2, 3)
179+
}
180+
181+
func adder(_ x: Int, _ y: Int) -> Int {
182+
x + y
183+
}
184+
185+
public struct Bar {}
186+
public struct Bat {
187+
let buz: String
188+
let fuz: Int
189+
190+
init(buz: String = "", fuz: Int = 0) {
191+
self.buz = buz
192+
self.fuz = fuz
193+
}
110194
}
111195

112196
extension Foo : MyProto {}

0 commit comments

Comments
 (0)