Skip to content

Commit e9abba2

Browse files
committed
[Serialization] Filter Decl to deserialize by their attributes
Add an alternative to getTopLevelDecls and getDeclChecked to limit which decls are deserialized by first looking at their attributes. If the attributes are accepted by a function passed as argument the decl is fully deserialized, otherwise it is ignored. The filter is included in the signature of existing functions in the Serilalization services, but I’ve added new methods for it in FileUnit and its subclasses to leave existing implementations untouched.
1 parent f240867 commit e9abba2

File tree

9 files changed

+135
-10
lines changed

9 files changed

+135
-10
lines changed

include/swift/AST/FileUnit.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,20 @@ class FileUnit : public DeclContext {
152152
/// The order of the results is not guaranteed to be meaningful.
153153
virtual void getTopLevelDecls(SmallVectorImpl<Decl*> &results) const {}
154154

155+
/// Finds top-level decls in this file filtered by their attributes.
156+
///
157+
/// This does a simple local lookup, not recursively looking through imports.
158+
/// The order of the results is not guaranteed to be meaningful.
159+
///
160+
/// \param Results Vector collecting the decls.
161+
///
162+
/// \param matchAttributes Check on the attributes of a decl to
163+
/// filter which decls to fully deserialize. Only decls with accepted
164+
/// attributes are deserialized and added to Results.
165+
virtual void
166+
getTopLevelDeclsWhereAttributesMatch(
167+
SmallVectorImpl<Decl*> &Results,
168+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const;
155169

156170
/// Finds all precedence group decls in this file.
157171
///

include/swift/AST/Module.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,20 @@ class ModuleDecl : public DeclContext, public TypeDecl {
476476
/// The order of the results is not guaranteed to be meaningful.
477477
void getTopLevelDecls(SmallVectorImpl<Decl*> &Results) const;
478478

479+
/// Finds top-level decls of this module filtered by their attributes.
480+
///
481+
/// This does a simple local lookup, not recursively looking through imports.
482+
/// The order of the results is not guaranteed to be meaningful.
483+
///
484+
/// \param Results Vector collecting the decls.
485+
///
486+
/// \param matchAttributes Check on the attributes of a decl to
487+
/// filter which decls to fully deserialize. Only decls with accepted
488+
/// attributes are deserialized and added to Results.
489+
void getTopLevelDeclsWhereAttributesMatch(
490+
SmallVectorImpl<Decl*> &Results,
491+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const;
492+
479493
/// Finds all local type decls of this module.
480494
///
481495
/// This does a simple local lookup, not recursively looking through imports.

include/swift/Serialization/SerializedModuleLoader.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,11 @@ class SerializedASTFile final : public LoadedFile {
351351

352352
virtual void getTopLevelDecls(SmallVectorImpl<Decl*> &results) const override;
353353

354+
virtual void
355+
getTopLevelDeclsWhereAttributesMatch(
356+
SmallVectorImpl<Decl*> &Results,
357+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const override;
358+
354359
virtual void
355360
getPrecedenceGroups(SmallVectorImpl<PrecedenceGroupDecl*> &Results) const override;
356361

lib/AST/Module.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,12 @@ void ModuleDecl::getTopLevelDecls(SmallVectorImpl<Decl*> &Results) const {
644644
FORWARD(getTopLevelDecls, (Results));
645645
}
646646

647+
void ModuleDecl::getTopLevelDeclsWhereAttributesMatch(
648+
SmallVectorImpl<Decl*> &Results,
649+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const {
650+
FORWARD(getTopLevelDeclsWhereAttributesMatch, (Results, matchAttributes));
651+
}
652+
647653
void SourceFile::getTopLevelDecls(SmallVectorImpl<Decl*> &Results) const {
648654
Results.append(Decls.begin(), Decls.end());
649655
}
@@ -1873,6 +1879,22 @@ void *FileUnit::operator new(size_t Bytes, ASTContext &C, unsigned Alignment) {
18731879
return C.Allocate(Bytes, Alignment);
18741880
}
18751881

1882+
void FileUnit::getTopLevelDeclsWhereAttributesMatch(
1883+
SmallVectorImpl<Decl*> &Results,
1884+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const {
1885+
auto prevSize = Results.size();
1886+
getTopLevelDecls(Results);
1887+
1888+
// Filter out unwanted decls that were just added to Results.
1889+
// Note: We could apply this check in all implementations of
1890+
// getTopLevelDecls instead or in everything that creates a Decl.
1891+
auto newEnd = std::remove_if(Results.begin() + prevSize, Results.end(),
1892+
[&matchAttributes](const Decl *D) -> bool {
1893+
return !matchAttributes(D->getAttrs());
1894+
});
1895+
Results.erase(newEnd, Results.end());
1896+
}
1897+
18761898
StringRef LoadedFile::getFilename() const {
18771899
return "";
18781900
}

lib/Serialization/Deserialization.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ const char TypeError::ID = '\0';
136136
void TypeError::anchor() {}
137137
const char ExtensionError::ID = '\0';
138138
void ExtensionError::anchor() {}
139+
const char DeclAttributesDidNotMatch::ID = '\0';
140+
void DeclAttributesDidNotMatch::anchor() {}
139141

140142
/// Skips a single record in the bitstream.
141143
///
@@ -2280,7 +2282,8 @@ class swift::DeclDeserializer {
22802282
/// passing each one to AddAttribute.
22812283
llvm::Error deserializeDeclAttributes();
22822284

2283-
Expected<Decl *> getDeclCheckedImpl();
2285+
Expected<Decl *> getDeclCheckedImpl(
2286+
llvm::function_ref<bool(DeclAttributes)> matchAttributes = nullptr);
22842287

22852288
Expected<Decl *> deserializeTypeAlias(ArrayRef<uint64_t> scratch,
22862289
StringRef blobData) {
@@ -3855,7 +3858,9 @@ class swift::DeclDeserializer {
38553858
};
38563859

38573860
Expected<Decl *>
3858-
ModuleFile::getDeclChecked(DeclID DID) {
3861+
ModuleFile::getDeclChecked(
3862+
DeclID DID,
3863+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) {
38593864
if (DID == 0)
38603865
return nullptr;
38613866

@@ -3868,9 +3873,14 @@ ModuleFile::getDeclChecked(DeclID DID) {
38683873
fatalIfNotSuccess(DeclTypeCursor.JumpToBit(declOrOffset));
38693874

38703875
Expected<Decl *> deserialized =
3871-
DeclDeserializer(*this, declOrOffset).getDeclCheckedImpl();
3876+
DeclDeserializer(*this, declOrOffset).getDeclCheckedImpl(
3877+
matchAttributes);
38723878
if (!deserialized)
38733879
return deserialized;
3880+
} else if (matchAttributes) {
3881+
// Decl was cached but we may need to filter it
3882+
if (!matchAttributes(declOrOffset.get()->getAttrs()))
3883+
return llvm::make_error<DeclAttributesDidNotMatch>();
38743884
}
38753885

38763886
// Tag every deserialized ValueDecl coming out of getDeclChecked with its ID.
@@ -4181,14 +4191,24 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
41814191
}
41824192

41834193
Expected<Decl *>
4184-
DeclDeserializer::getDeclCheckedImpl() {
4185-
if (auto s = ctx.Stats)
4186-
s->getFrontendCounters().NumDeclsDeserialized++;
4194+
DeclDeserializer::getDeclCheckedImpl(
4195+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) {
41874196

41884197
auto attrError = deserializeDeclAttributes();
41894198
if (attrError)
41904199
return std::move(attrError);
41914200

4201+
if (matchAttributes) {
4202+
// Deserialize the full decl only if matchAttributes finds a match.
4203+
DeclAttributes attrs = DeclAttributes();
4204+
attrs.setRawAttributeChain(DAttrs);
4205+
if (!matchAttributes(attrs))
4206+
return llvm::make_error<DeclAttributesDidNotMatch>();
4207+
}
4208+
4209+
if (auto s = ctx.Stats)
4210+
s->getFrontendCounters().NumDeclsDeserialized++;
4211+
41924212
// FIXME: @_dynamicReplacement(for:) includes a reference to another decl,
41934213
// usually in the same type, and that can result in this decl being
41944214
// re-entrantly deserialized. If that happens, don't fail here.

lib/Serialization/DeserializationErrors.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,26 @@ class SILEntityError : public llvm::ErrorInfo<SILEntityError> {
411411
}
412412
};
413413

414+
// Decl was not deserialized because its attributes did not match the filter.
415+
//
416+
// \sa getDeclChecked
417+
class DeclAttributesDidNotMatch : public llvm::ErrorInfo<DeclAttributesDidNotMatch> {
418+
friend ErrorInfo;
419+
static const char ID;
420+
void anchor() override;
421+
422+
public:
423+
DeclAttributesDidNotMatch() {}
424+
425+
void log(raw_ostream &OS) const override {
426+
OS << "Decl attributes did not match filter";
427+
}
428+
429+
std::error_code convertToErrorCode() const override {
430+
return llvm::inconvertibleErrorCode();
431+
}
432+
};
433+
414434
LLVM_NODISCARD
415435
static inline std::unique_ptr<llvm::ErrorInfoBase>
416436
takeErrorInfo(llvm::Error error) {

lib/Serialization/ModuleFile.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2539,11 +2539,20 @@ ModuleFile::collectLinkLibraries(ModuleDecl::LinkLibraryCallback callback) const
25392539
callback(LinkLibrary(Name, LibraryKind::Framework));
25402540
}
25412541

2542-
void ModuleFile::getTopLevelDecls(SmallVectorImpl<Decl *> &results) {
2542+
void ModuleFile::getTopLevelDecls(
2543+
SmallVectorImpl<Decl *> &results,
2544+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) {
25432545
PrettyStackTraceModuleFile stackEntry(*this);
25442546
for (DeclID entry : OrderedTopLevelDecls) {
2545-
Expected<Decl *> declOrError = getDeclChecked(entry);
2547+
Expected<Decl *> declOrError = getDeclChecked(entry, matchAttributes);
25462548
if (!declOrError) {
2549+
if (declOrError.errorIsA<DeclAttributesDidNotMatch>()) {
2550+
// Decl rejected by matchAttributes, ignore it.
2551+
assert(matchAttributes);
2552+
consumeError(declOrError.takeError());
2553+
continue;
2554+
}
2555+
25472556
if (!getContext().LangOpts.EnableDeserializationRecovery)
25482557
fatal(declOrError.takeError());
25492558
consumeError(declOrError.takeError());

lib/Serialization/ModuleFile.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,15 @@ class ModuleFile
786786
void collectLinkLibraries(ModuleDecl::LinkLibraryCallback callback) const;
787787

788788
/// Adds all top-level decls to the given vector.
789-
void getTopLevelDecls(SmallVectorImpl<Decl*> &Results);
789+
///
790+
/// \param Results Vector collecting the decls.
791+
///
792+
/// \param matchAttributes Optional check on the attributes of a decl to
793+
/// filter which decls to fully deserialize. Only decls with accepted
794+
/// attributes are deserialized and added to Results.
795+
void getTopLevelDecls(
796+
SmallVectorImpl<Decl*> &Results,
797+
llvm::function_ref<bool(DeclAttributes)> matchAttributes = nullptr);
790798

791799
/// Adds all precedence groups to the given vector.
792800
void getPrecedenceGroups(SmallVectorImpl<PrecedenceGroupDecl*> &Results);
@@ -888,8 +896,15 @@ class ModuleFile
888896
/// Returns the decl with the given ID, deserializing it if needed.
889897
///
890898
/// \param DID The ID for the decl within this module.
899+
///
900+
/// \param matchAttributes Optional check on the attributes of the decl to
901+
/// determine if it should be fully deserialized and returned. If the
902+
/// attributes fail the check, the decl is not deserialized and
903+
/// \c DeclAttributesDidNotMatch is returned.
891904
llvm::Expected<Decl *>
892-
getDeclChecked(serialization::DeclID DID);
905+
getDeclChecked(
906+
serialization::DeclID DID,
907+
llvm::function_ref<bool(DeclAttributes)> matchAttributes = nullptr);
893908

894909
/// Returns the decl context with the given ID, deserializing it if needed.
895910
DeclContext *getDeclContext(serialization::DeclContextID DID);

lib/Serialization/SerializedModuleLoader.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,12 @@ SerializedASTFile::getTopLevelDecls(SmallVectorImpl<Decl*> &results) const {
11201120
File.getTopLevelDecls(results);
11211121
}
11221122

1123+
void SerializedASTFile::getTopLevelDeclsWhereAttributesMatch(
1124+
SmallVectorImpl<Decl*> &results,
1125+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const {
1126+
File.getTopLevelDecls(results, matchAttributes);
1127+
}
1128+
11231129
void SerializedASTFile::getPrecedenceGroups(
11241130
SmallVectorImpl<PrecedenceGroupDecl*> &results) const {
11251131
File.getPrecedenceGroups(results);

0 commit comments

Comments
 (0)