Skip to content

Commit 3d9587c

Browse files
authored
Merge pull request #28415 from xymus/filter-by-attrs
[Serialization & ClangImporter] Filter decl to deserialize by their `@objc` attributes
2 parents 53e2b5c + 7986024 commit 3d9587c

File tree

10 files changed

+154
-20
lines changed

10 files changed

+154
-20
lines changed

include/swift/AST/FileUnit.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,21 @@ class FileUnit : public DeclContext {
152152
/// The order of the results is not guaranteed to be meaningful.
153153
virtual void getTopLevelDecls(SmallVectorImpl<Decl*> &results) const {}
154154

155+
/// Finds top-level decls in this file filtered by their attributes.
156+
///
157+
/// This does a simple local lookup, not recursively looking through imports.
158+
/// The order of the results is not guaranteed to be meaningful.
159+
///
160+
/// \param Results Vector collecting the decls.
161+
///
162+
/// \param matchAttributes Check on the attributes of a decl to keep only
163+
/// decls with matching attributes. The subclass SerializedASTFile checks the
164+
/// attributes first to only deserialize decls with accepted attributes,
165+
/// limiting deserialization work.
166+
virtual void
167+
getTopLevelDeclsWhereAttributesMatch(
168+
SmallVectorImpl<Decl*> &Results,
169+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const;
155170

156171
/// Finds all precedence group decls in this file.
157172
///

include/swift/AST/Module.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,20 @@ class ModuleDecl : public DeclContext, public TypeDecl {
476476
/// The order of the results is not guaranteed to be meaningful.
477477
void getTopLevelDecls(SmallVectorImpl<Decl*> &Results) const;
478478

479+
/// Finds top-level decls of this module filtered by their attributes.
480+
///
481+
/// This does a simple local lookup, not recursively looking through imports.
482+
/// The order of the results is not guaranteed to be meaningful.
483+
///
484+
/// \param Results Vector collecting the decls.
485+
///
486+
/// \param matchAttributes Check on the attributes of a decl to
487+
/// filter which decls to fully deserialize. Only decls with accepted
488+
/// attributes are deserialized and added to Results.
489+
void getTopLevelDeclsWhereAttributesMatch(
490+
SmallVectorImpl<Decl*> &Results,
491+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const;
492+
479493
/// Finds all local type decls of this module.
480494
///
481495
/// This does a simple local lookup, not recursively looking through imports.

include/swift/Serialization/SerializedModuleLoader.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,11 @@ class SerializedASTFile final : public LoadedFile {
351351

352352
virtual void getTopLevelDecls(SmallVectorImpl<Decl*> &results) const override;
353353

354+
virtual void
355+
getTopLevelDeclsWhereAttributesMatch(
356+
SmallVectorImpl<Decl*> &Results,
357+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const override;
358+
354359
virtual void
355360
getPrecedenceGroups(SmallVectorImpl<PrecedenceGroupDecl*> &Results) const override;
356361

lib/AST/Module.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,12 @@ void ModuleDecl::getTopLevelDecls(SmallVectorImpl<Decl*> &Results) const {
644644
FORWARD(getTopLevelDecls, (Results));
645645
}
646646

647+
void ModuleDecl::getTopLevelDeclsWhereAttributesMatch(
648+
SmallVectorImpl<Decl*> &Results,
649+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const {
650+
FORWARD(getTopLevelDeclsWhereAttributesMatch, (Results, matchAttributes));
651+
}
652+
647653
void SourceFile::getTopLevelDecls(SmallVectorImpl<Decl*> &Results) const {
648654
Results.append(Decls.begin(), Decls.end());
649655
}
@@ -1873,6 +1879,22 @@ void *FileUnit::operator new(size_t Bytes, ASTContext &C, unsigned Alignment) {
18731879
return C.Allocate(Bytes, Alignment);
18741880
}
18751881

1882+
void FileUnit::getTopLevelDeclsWhereAttributesMatch(
1883+
SmallVectorImpl<Decl*> &Results,
1884+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const {
1885+
auto prevSize = Results.size();
1886+
getTopLevelDecls(Results);
1887+
1888+
// Filter out unwanted decls that were just added to Results.
1889+
// Note: We could apply this check in all implementations of
1890+
// getTopLevelDecls instead or in everything that creates a Decl.
1891+
auto newEnd = std::remove_if(Results.begin() + prevSize, Results.end(),
1892+
[&matchAttributes](const Decl *D) -> bool {
1893+
return !matchAttributes(D->getAttrs());
1894+
});
1895+
Results.erase(newEnd, Results.end());
1896+
}
1897+
18761898
StringRef LoadedFile::getFilename() const {
18771899
return "";
18781900
}

lib/ClangImporter/ImportDecl.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4563,17 +4563,25 @@ namespace {
45634563

45644564
if (!found) {
45654565
// Try harder to find a match looking at just custom Objective-C names.
4566-
SmallVector<Decl *, 64> allTopLevelDecls;
4567-
overlay->getTopLevelDecls(allTopLevelDecls);
4568-
for (auto result : allTopLevelDecls) {
4566+
// Limit what we deserialize to decls with an @objc attribute.
4567+
SmallVector<Decl *, 4> matchingTopLevelDecls;
4568+
4569+
// Get decls with a matching @objc attribute
4570+
overlay->getTopLevelDeclsWhereAttributesMatch(
4571+
matchingTopLevelDecls,
4572+
[&name](const DeclAttributes attrs) -> bool {
4573+
if (auto objcAttr = attrs.getAttribute<ObjCAttr>())
4574+
if (auto objcName = objcAttr->getName())
4575+
return objcName->getSimpleName() == name;
4576+
return false;
4577+
});
4578+
4579+
// Filter by decl kind
4580+
for (auto result : matchingTopLevelDecls) {
45694581
if (auto singleResult = dyn_cast<T>(result)) {
4570-
// The base name _could_ match but it's irrelevant here.
4571-
if (isMatch(singleResult, /*baseNameMatches=*/false,
4572-
/*allowObjCMismatch=*/false)) {
4573-
if (found)
4574-
return nullptr;
4575-
found = singleResult;
4576-
}
4582+
if (found)
4583+
return nullptr;
4584+
found = singleResult;
45774585
}
45784586
}
45794587
}

lib/Serialization/Deserialization.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ const char TypeError::ID = '\0';
136136
void TypeError::anchor() {}
137137
const char ExtensionError::ID = '\0';
138138
void ExtensionError::anchor() {}
139+
const char DeclAttributesDidNotMatch::ID = '\0';
140+
void DeclAttributesDidNotMatch::anchor() {}
139141

140142
/// Skips a single record in the bitstream.
141143
///
@@ -2280,7 +2282,8 @@ class swift::DeclDeserializer {
22802282
/// passing each one to AddAttribute.
22812283
llvm::Error deserializeDeclAttributes();
22822284

2283-
Expected<Decl *> getDeclCheckedImpl();
2285+
Expected<Decl *> getDeclCheckedImpl(
2286+
llvm::function_ref<bool(DeclAttributes)> matchAttributes = nullptr);
22842287

22852288
Expected<Decl *> deserializeTypeAlias(ArrayRef<uint64_t> scratch,
22862289
StringRef blobData) {
@@ -3855,7 +3858,9 @@ class swift::DeclDeserializer {
38553858
};
38563859

38573860
Expected<Decl *>
3858-
ModuleFile::getDeclChecked(DeclID DID) {
3861+
ModuleFile::getDeclChecked(
3862+
DeclID DID,
3863+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) {
38593864
if (DID == 0)
38603865
return nullptr;
38613866

@@ -3868,9 +3873,14 @@ ModuleFile::getDeclChecked(DeclID DID) {
38683873
fatalIfNotSuccess(DeclTypeCursor.JumpToBit(declOrOffset));
38693874

38703875
Expected<Decl *> deserialized =
3871-
DeclDeserializer(*this, declOrOffset).getDeclCheckedImpl();
3876+
DeclDeserializer(*this, declOrOffset).getDeclCheckedImpl(
3877+
matchAttributes);
38723878
if (!deserialized)
38733879
return deserialized;
3880+
} else if (matchAttributes) {
3881+
// Decl was cached but we may need to filter it
3882+
if (!matchAttributes(declOrOffset.get()->getAttrs()))
3883+
return llvm::make_error<DeclAttributesDidNotMatch>();
38743884
}
38753885

38763886
// Tag every deserialized ValueDecl coming out of getDeclChecked with its ID.
@@ -4203,14 +4213,24 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
42034213
}
42044214

42054215
Expected<Decl *>
4206-
DeclDeserializer::getDeclCheckedImpl() {
4207-
if (auto s = ctx.Stats)
4208-
s->getFrontendCounters().NumDeclsDeserialized++;
4216+
DeclDeserializer::getDeclCheckedImpl(
4217+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) {
42094218

42104219
auto attrError = deserializeDeclAttributes();
42114220
if (attrError)
42124221
return std::move(attrError);
42134222

4223+
if (matchAttributes) {
4224+
// Deserialize the full decl only if matchAttributes finds a match.
4225+
DeclAttributes attrs = DeclAttributes();
4226+
attrs.setRawAttributeChain(DAttrs);
4227+
if (!matchAttributes(attrs))
4228+
return llvm::make_error<DeclAttributesDidNotMatch>();
4229+
}
4230+
4231+
if (auto s = ctx.Stats)
4232+
s->getFrontendCounters().NumDeclsDeserialized++;
4233+
42144234
// FIXME: @_dynamicReplacement(for:) includes a reference to another decl,
42154235
// usually in the same type, and that can result in this decl being
42164236
// re-entrantly deserialized. If that happens, don't fail here.

lib/Serialization/DeserializationErrors.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,26 @@ class SILEntityError : public llvm::ErrorInfo<SILEntityError> {
411411
}
412412
};
413413

414+
// Decl was not deserialized because its attributes did not match the filter.
415+
//
416+
// \sa getDeclChecked
417+
class DeclAttributesDidNotMatch : public llvm::ErrorInfo<DeclAttributesDidNotMatch> {
418+
friend ErrorInfo;
419+
static const char ID;
420+
void anchor() override;
421+
422+
public:
423+
DeclAttributesDidNotMatch() {}
424+
425+
void log(raw_ostream &OS) const override {
426+
OS << "Decl attributes did not match filter";
427+
}
428+
429+
std::error_code convertToErrorCode() const override {
430+
return llvm::inconvertibleErrorCode();
431+
}
432+
};
433+
414434
LLVM_NODISCARD
415435
static inline std::unique_ptr<llvm::ErrorInfoBase>
416436
takeErrorInfo(llvm::Error error) {

lib/Serialization/ModuleFile.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2539,11 +2539,20 @@ ModuleFile::collectLinkLibraries(ModuleDecl::LinkLibraryCallback callback) const
25392539
callback(LinkLibrary(Name, LibraryKind::Framework));
25402540
}
25412541

2542-
void ModuleFile::getTopLevelDecls(SmallVectorImpl<Decl *> &results) {
2542+
void ModuleFile::getTopLevelDecls(
2543+
SmallVectorImpl<Decl *> &results,
2544+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) {
25432545
PrettyStackTraceModuleFile stackEntry(*this);
25442546
for (DeclID entry : OrderedTopLevelDecls) {
2545-
Expected<Decl *> declOrError = getDeclChecked(entry);
2547+
Expected<Decl *> declOrError = getDeclChecked(entry, matchAttributes);
25462548
if (!declOrError) {
2549+
if (declOrError.errorIsA<DeclAttributesDidNotMatch>()) {
2550+
// Decl rejected by matchAttributes, ignore it.
2551+
assert(matchAttributes);
2552+
consumeError(declOrError.takeError());
2553+
continue;
2554+
}
2555+
25472556
if (!getContext().LangOpts.EnableDeserializationRecovery)
25482557
fatal(declOrError.takeError());
25492558
consumeError(declOrError.takeError());

lib/Serialization/ModuleFile.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,15 @@ class ModuleFile
786786
void collectLinkLibraries(ModuleDecl::LinkLibraryCallback callback) const;
787787

788788
/// Adds all top-level decls to the given vector.
789-
void getTopLevelDecls(SmallVectorImpl<Decl*> &Results);
789+
///
790+
/// \param Results Vector collecting the decls.
791+
///
792+
/// \param matchAttributes Optional check on the attributes of a decl to
793+
/// filter which decls to fully deserialize. Only decls with accepted
794+
/// attributes are deserialized and added to Results.
795+
void getTopLevelDecls(
796+
SmallVectorImpl<Decl*> &Results,
797+
llvm::function_ref<bool(DeclAttributes)> matchAttributes = nullptr);
790798

791799
/// Adds all precedence groups to the given vector.
792800
void getPrecedenceGroups(SmallVectorImpl<PrecedenceGroupDecl*> &Results);
@@ -888,8 +896,15 @@ class ModuleFile
888896
/// Returns the decl with the given ID, deserializing it if needed.
889897
///
890898
/// \param DID The ID for the decl within this module.
899+
///
900+
/// \param matchAttributes Optional check on the attributes of the decl to
901+
/// determine if it should be fully deserialized and returned. If the
902+
/// attributes fail the check, the decl is not deserialized and
903+
/// \c DeclAttributesDidNotMatch is returned.
891904
llvm::Expected<Decl *>
892-
getDeclChecked(serialization::DeclID DID);
905+
getDeclChecked(
906+
serialization::DeclID DID,
907+
llvm::function_ref<bool(DeclAttributes)> matchAttributes = nullptr);
893908

894909
/// Returns the decl context with the given ID, deserializing it if needed.
895910
DeclContext *getDeclContext(serialization::DeclContextID DID);

lib/Serialization/SerializedModuleLoader.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,12 @@ SerializedASTFile::getTopLevelDecls(SmallVectorImpl<Decl*> &results) const {
11201120
File.getTopLevelDecls(results);
11211121
}
11221122

1123+
void SerializedASTFile::getTopLevelDeclsWhereAttributesMatch(
1124+
SmallVectorImpl<Decl*> &results,
1125+
llvm::function_ref<bool(DeclAttributes)> matchAttributes) const {
1126+
File.getTopLevelDecls(results, matchAttributes);
1127+
}
1128+
11231129
void SerializedASTFile::getPrecedenceGroups(
11241130
SmallVectorImpl<PrecedenceGroupDecl*> &results) const {
11251131
File.getPrecedenceGroups(results);

0 commit comments

Comments
 (0)