Skip to content

Commit 4f0c6fa

Browse files
committed
fix: only generate trait bound for associated types in field types
1 parent 1c25885 commit 4f0c6fa

File tree

3 files changed

+139
-71
lines changed

3 files changed

+139
-71
lines changed

crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,66 @@ impl <A: core::clone::Clone, B: core::clone::Clone, > core::clone::Clone for Com
114114
);
115115
}
116116

117+
#[test]
118+
fn test_clone_expand_with_associated_types() {
119+
check(
120+
r#"
121+
//- minicore: derive, clone
122+
trait Trait {
123+
type InWc;
124+
type InFieldQualified;
125+
type InFieldShorthand;
126+
type InGenericArg;
127+
}
128+
trait Marker {}
129+
struct Vec<T>(T);
130+
131+
#[derive(Clone)]
132+
struct Foo<T: Trait>
133+
where
134+
<T as Trait>::InWc: Marker,
135+
{
136+
qualified: <T as Trait>::InFieldQualified,
137+
shorthand: T::InFieldShorthand,
138+
generic: Vec<T::InGenericArg>,
139+
}
140+
"#,
141+
expect![[r#"
142+
trait Trait {
143+
type InWc;
144+
type InFieldQualified;
145+
type InFieldShorthand;
146+
type InGenericArg;
147+
}
148+
trait Marker {}
149+
struct Vec<T>(T);
150+
151+
#[derive(Clone)]
152+
struct Foo<T: Trait>
153+
where
154+
<T as Trait>::InWc: Marker,
155+
{
156+
qualified: <T as Trait>::InFieldQualified,
157+
shorthand: T::InFieldShorthand,
158+
generic: Vec<T::InGenericArg>,
159+
}
160+
161+
impl <T: core::clone::Clone, > core::clone::Clone for Foo<T, > where T: Trait, T::InFieldShorthand: core::clone::Clone, T::InGenericArg: core::clone::Clone, {
162+
fn clone(&self ) -> Self {
163+
match self {
164+
Foo {
165+
qualified: qualified, shorthand: shorthand, generic: generic,
166+
}
167+
=>Foo {
168+
qualified: qualified.clone(), shorthand: shorthand.clone(), generic: generic.clone(),
169+
}
170+
,
171+
}
172+
}
173+
}"#]],
174+
);
175+
}
176+
117177
#[test]
118178
fn test_clone_expand_with_const_generics() {
119179
check(

crates/hir-expand/src/builtin_derive_macro.rs

Lines changed: 74 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,16 @@ use ::tt::Ident;
44
use base_db::{CrateOrigin, LangCrateOrigin};
55
use itertools::izip;
66
use mbe::TokenMap;
7-
use std::collections::HashSet;
7+
use rustc_hash::FxHashSet;
88
use stdx::never;
99
use tracing::debug;
1010

11-
use crate::tt::{self, TokenId};
12-
use syntax::{
13-
ast::{
14-
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName,
15-
HasTypeBounds, PathType,
16-
},
17-
match_ast,
11+
use crate::{
12+
name::{AsName, Name},
13+
tt::{self, TokenId},
14+
};
15+
use syntax::ast::{
16+
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
1817
};
1918

2019
use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId};
@@ -201,41 +200,54 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
201200
debug!("no module item parsed");
202201
ExpandError::Other("no item found".into())
203202
})?;
204-
let node = item.syntax();
205-
let (name, params, shape) = match_ast! {
206-
match node {
207-
ast::Struct(it) => (it.name(), it.generic_param_list(), AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?)),
208-
ast::Enum(it) => {
209-
let default_variant = it.variant_list().into_iter().flat_map(|x| x.variants()).position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
210-
(
211-
it.name(),
212-
it.generic_param_list(),
213-
AdtShape::Enum {
214-
default_variant,
215-
variants: it.variant_list()
216-
.into_iter()
217-
.flat_map(|x| x.variants())
218-
.map(|x| Ok((name_to_token(&token_map,x.name())?, VariantShape::from(x.field_list(), &token_map)?))).collect::<Result<_, ExpandError>>()?
219-
}
220-
)
221-
},
222-
ast::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
223-
_ => {
224-
debug!("unexpected node is {:?}", node);
225-
return Err(ExpandError::Other("expected struct, enum or union".into()))
226-
},
203+
let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| {
204+
debug!("expected adt, found: {:?}", item);
205+
ExpandError::Other("expected struct, enum or union".into())
206+
})?;
207+
let (name, generic_param_list, shape) = match &adt {
208+
ast::Adt::Struct(it) => (
209+
it.name(),
210+
it.generic_param_list(),
211+
AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?),
212+
),
213+
ast::Adt::Enum(it) => {
214+
let default_variant = it
215+
.variant_list()
216+
.into_iter()
217+
.flat_map(|x| x.variants())
218+
.position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
219+
(
220+
it.name(),
221+
it.generic_param_list(),
222+
AdtShape::Enum {
223+
default_variant,
224+
variants: it
225+
.variant_list()
226+
.into_iter()
227+
.flat_map(|x| x.variants())
228+
.map(|x| {
229+
Ok((
230+
name_to_token(&token_map, x.name())?,
231+
VariantShape::from(x.field_list(), &token_map)?,
232+
))
233+
})
234+
.collect::<Result<_, ExpandError>>()?,
235+
},
236+
)
227237
}
238+
ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
228239
};
229-
let mut param_type_set: HashSet<String> = HashSet::new();
230-
let param_types = params
240+
241+
let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
242+
let param_types = generic_param_list
231243
.into_iter()
232244
.flat_map(|param_list| param_list.type_or_const_params())
233245
.map(|param| {
234246
let name = {
235247
let this = param.name();
236248
match this {
237249
Some(x) => {
238-
param_type_set.insert(x.to_string());
250+
param_type_set.insert(x.as_name());
239251
mbe::syntax_node_to_token_tree(x.syntax()).0
240252
}
241253
None => tt::Subtree::empty(),
@@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
259271
(name, ty, bounds)
260272
})
261273
.collect();
262-
let is_associated_type = |p: &PathType| {
263-
if let Some(p) = p.path() {
264-
if let Some(parent) = p.qualifier() {
265-
if let Some(x) = parent.segment() {
266-
if let Some(x) = x.path_type() {
267-
if let Some(x) = x.path() {
268-
if let Some(pname) = x.as_single_name_ref() {
269-
if param_type_set.contains(&pname.to_string()) {
270-
// <T as Trait>::Assoc
271-
return true;
272-
}
273-
}
274-
}
275-
}
276-
}
277-
if let Some(pname) = parent.as_single_name_ref() {
278-
if param_type_set.contains(&pname.to_string()) {
279-
// T::Assoc
280-
return true;
281-
}
282-
}
283-
}
284-
}
285-
false
274+
275+
// For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
276+
// types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
277+
// also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
278+
// does not do that for some unknown reason.
279+
//
280+
// See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730.
281+
// [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378
282+
283+
// It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped
284+
// `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that
285+
// we should not inspect `ast::PathType`s in parameter bounds and where clauses.
286+
let field_list = match adt {
287+
ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()),
288+
ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()),
289+
ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()),
286290
};
287-
let associated_types = node
288-
.descendants()
289-
.filter_map(PathType::cast)
290-
.filter(is_associated_type)
291+
let associated_types = field_list
292+
.into_iter()
293+
.flat_map(|it| it.descendants())
294+
.filter_map(ast::PathType::cast)
295+
.filter_map(|p| {
296+
let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name();
297+
param_type_set.contains(&name).then_some(p)
298+
})
291299
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
292-
.collect::<Vec<_>>();
300+
.collect();
293301
let name_token = name_to_token(&token_map, name)?;
294302
Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
295303
}
@@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Id
334342
/// }
335343
/// ```
336344
///
337-
/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
345+
/// where B1, ..., BN are the bounds given by `bounds_paths`. Z is a phantom type, and
338346
/// therefore does not get bound by the derived trait.
339347
fn expand_simple_derive(
340348
tt: &tt::Subtree,
341349
trait_path: tt::Subtree,
342-
trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
350+
make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
343351
) -> ExpandResult<tt::Subtree> {
344352
let info = match parse_adt(tt) {
345353
Ok(info) => info,
346354
Err(e) => return ExpandResult::new(tt::Subtree::empty(), e),
347355
};
348-
let trait_body = trait_body(&info);
356+
let trait_body = make_trait_body(&info);
349357
let mut where_block = vec![];
350358
let (params, args): (Vec<_>, Vec<_>) = info
351359
.param_types

crates/hir-ty/src/tests/traits.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4335,8 +4335,9 @@ fn derive_macro_bounds() {
43354335
#[derive(Clone)]
43364336
struct AssocGeneric<T: Tr>(T::Assoc);
43374337
4338-
#[derive(Clone)]
4339-
struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
4338+
// Currently rustc does not accept this.
4339+
// #[derive(Clone)]
4340+
// struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
43404341
43414342
#[derive(Clone)]
43424343
struct AssocGeneric3<T: Tr>(Generic<T::Assoc>);
@@ -4361,9 +4362,8 @@ fn derive_macro_bounds() {
43614362
let x: &AssocGeneric<Copy> = &AssocGeneric(NotCopy);
43624363
let x = x.clone();
43634364
//^ &AssocGeneric<Copy>
4364-
let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
4365-
let x = x.clone();
4366-
//^ &AssocGeneric2<Copy>
4365+
// let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
4366+
// let x = x.clone();
43674367
let x: &AssocGeneric3<Copy> = &AssocGeneric3(Generic(NotCopy));
43684368
let x = x.clone();
43694369
//^ &AssocGeneric3<Copy>

0 commit comments

Comments
 (0)