@@ -65,8 +65,9 @@ def asdl_of(name, obj):
65
65
class EmitVisitor (asdl .VisitorBase ):
66
66
"""Visit that emits lines"""
67
67
68
- def __init__ (self , file ):
68
+ def __init__ (self , file , typeinfo ):
69
69
self .file = file
70
+ self .typeinfo = typeinfo
70
71
self .identifiers = set ()
71
72
super (EmitVisitor , self ).__init__ ()
72
73
@@ -165,8 +166,7 @@ def rust_field(field_name):
165
166
166
167
class TypeInfoEmitVisitor (EmitVisitor ):
167
168
def __init__ (self , file , typeinfo ):
168
- self .typeinfo = typeinfo
169
- super ().__init__ (file )
169
+ super ().__init__ (file , typeinfo )
170
170
171
171
def has_userdata (self , typ ):
172
172
return self .typeinfo [typ ].has_userdata
@@ -327,7 +327,11 @@ def visitModule(self, mod, depth):
327
327
self .emit ("type Error;" , depth + 1 )
328
328
self .emit (
329
329
"fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error>;" ,
330
- depth + 2 ,
330
+ depth + 1 ,
331
+ )
332
+ self .emit (
333
+ "fn map_located<T>(&mut self, located: Located<T, U>) -> Result<Located<T, Self::TargetU>, Self::Error> { let custom = self.map_user(located.custom)?; Ok(Located { range: located.range, custom, node: located.node }) }" ,
334
+ depth + 1 ,
331
335
)
332
336
for dfn in mod .dfns :
333
337
self .visit (dfn , depth + 2 )
@@ -352,7 +356,7 @@ def visitModule(self, mod, depth):
352
356
depth ,
353
357
)
354
358
self .emit (
355
- "Ok(Located { custom: folder.map_user( node.custom)? , range: node.range, node: f(folder, node.node)? })" ,
359
+ "let node = folder.map_located(node)?; Ok(Located { custom: node.custom, range: node.range, node: f(folder, node.node)? })" ,
356
360
depth + 1 ,
357
361
)
358
362
self .emit ("}" , depth )
@@ -575,11 +579,15 @@ def visitSum(self, sum, name, depth):
575
579
rustname = enumname = get_rust_type (name )
576
580
if sum .attributes :
577
581
rustname = enumname + "Kind"
582
+ if sum .attributes or self .typeinfo [name ].has_userdata :
583
+ custom = "<LocationRange>"
584
+ else :
585
+ custom = ""
578
586
579
- self .emit (f"impl NamedNode for ast::{ rustname } {{" , depth )
587
+ self .emit (f"impl NamedNode for ast::{ rustname } { custom } {{" , depth )
580
588
self .emit (f"const NAME: &'static str = { json .dumps (name )} ;" , depth + 1 )
581
589
self .emit ("}" , depth )
582
- self .emit (f"impl Node for ast::{ rustname } {{" , depth )
590
+ self .emit (f"impl Node for ast::{ rustname } { custom } {{" , depth )
583
591
self .emit (
584
592
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {" , depth + 1
585
593
)
@@ -656,7 +664,10 @@ def gen_sum_fromobj(self, sum, sumname, enumname, rustname, depth):
656
664
for cons in sum .types :
657
665
self .emit (f"if _cls.is(Node{ cons .name } ::static_type()) {{" , depth )
658
666
if cons .fields :
659
- self .emit (f"ast::{ rustname } ::{ cons .name } (ast::{ enumname } { cons .name } {{" , depth + 1 )
667
+ self .emit (
668
+ f"ast::{ rustname } ::{ cons .name } (ast::{ enumname } { cons .name } ::<LocationRange> {{" ,
669
+ depth + 1 ,
670
+ )
660
671
self .gen_construction_fields (cons , sumname , depth + 1 )
661
672
self .emit ("})" , depth + 1 )
662
673
else :
@@ -691,7 +702,7 @@ def gen_construction(self, cons_path, cons, name, depth):
691
702
def extract_location (self , typename , depth ):
692
703
row = self .decode_field (asdl .Field ("int" , "lineno" ), typename )
693
704
column = self .decode_field (asdl .Field ("int" , "col_offset" ), typename )
694
- self .emit (f"let _location = ast:: Location::new({ row } , { column } );" , depth )
705
+ self .emit (f"let _location = Location::new({ row } , { column } );" , depth )
695
706
696
707
def decode_field (self , field , typename ):
697
708
name = json .dumps (field .name )
@@ -717,86 +728,20 @@ def write_ast_def(mod, typeinfo, f):
717
728
"""
718
729
#![allow(clippy::derive_partial_eq_without_eq)]
719
730
720
- pub use crate::constant::*;
721
- pub use rustpython_compiler_core::text_size::{TextSize, TextRange};
731
+ pub use crate::{Located, constant::*} ;
732
+ pub use rustpython_compiler_core::{ text_size::{TextSize, TextRange} };
722
733
723
734
type Ident = String;
724
735
\n
725
736
"""
726
737
)
727
738
)
728
- StructVisitor (f , typeinfo ).emit_attrs (0 )
729
- f .write (
730
- textwrap .dedent (
731
- """
732
- pub struct Located<T, U = ()> {
733
- pub range: TextRange,
734
- pub custom: U,
735
- pub node: T,
736
- }
737
-
738
- impl<T> Located<T> {
739
- pub fn new(start: TextSize, end: TextSize, node: T) -> Self {
740
- Self { range: TextRange::new(start, end), custom: (), node }
741
- }
742
-
743
- /// Creates a new node that spans the position specified by `range`.
744
- pub fn with_range(node: T, range: TextRange) -> Self {
745
- Self {
746
- range,
747
- custom: (),
748
- node,
749
- }
750
- }
751
-
752
- /// Returns the absolute start position of the node from the beginning of the document.
753
- #[inline]
754
- pub const fn start(&self) -> TextSize {
755
- self.range.start()
756
- }
757
-
758
- /// Returns the node
759
- #[inline]
760
- pub fn node(&self) -> &T {
761
- &self.node
762
- }
763
-
764
- /// Consumes self and returns the node.
765
- #[inline]
766
- pub fn into_node(self) -> T {
767
- self.node
768
- }
769
-
770
- /// Returns the `range` of the node. The range offsets are absolute to the start of the document.
771
- #[inline]
772
- pub const fn range(&self) -> TextRange {
773
- self.range
774
- }
775
-
776
- /// Returns the absolute position at which the node ends in the source document.
777
- #[inline]
778
- pub const fn end(&self) -> TextSize {
779
- self.range.end()
780
- }
781
- }
782
-
783
- impl<T, U> std::ops::Deref for Located<T, U> {
784
- type Target = T;
785
-
786
- fn deref(&self) -> &Self::Target {
787
- &self.node
788
- }
789
- }
790
- \n
791
- """ .lstrip ()
792
- )
793
- )
794
739
795
740
c = ChainOfVisitors (StructVisitor (f , typeinfo ), FoldModuleVisitor (f , typeinfo ))
796
741
c .visit (mod )
797
742
798
743
799
- def write_ast_mod (mod , f ):
744
+ def write_ast_mod (mod , typeinfo , f ):
800
745
f .write (
801
746
textwrap .dedent (
802
747
"""
@@ -809,7 +754,11 @@ def write_ast_mod(mod, f):
809
754
)
810
755
)
811
756
812
- c = ChainOfVisitors (ClassDefVisitor (f ), TraitImplVisitor (f ), ExtendModuleVisitor (f ))
757
+ c = ChainOfVisitors (
758
+ ClassDefVisitor (f , typeinfo ),
759
+ TraitImplVisitor (f , typeinfo ),
760
+ ExtendModuleVisitor (f , typeinfo ),
761
+ )
813
762
c .visit (mod )
814
763
815
764
@@ -830,7 +779,7 @@ def main(input_filename, ast_mod_filename, ast_def_filename, dump_module=False):
830
779
write_ast_def (mod , typeinfo , def_file )
831
780
832
781
mod_file .write (auto_gen_msg )
833
- write_ast_mod (mod , mod_file )
782
+ write_ast_mod (mod , typeinfo , mod_file )
834
783
835
784
print (f"{ ast_def_filename } , { ast_mod_filename } regenerated." )
836
785
0 commit comments