@@ -74,16 +74,20 @@ def generate_anchors(
74
74
return base_anchors .round ()
75
75
76
76
def set_cell_anchors (self , dtype : torch .dtype , device : torch .device ):
77
- self . cell_anchors = [cell_anchor .to (dtype = dtype , device = device ) for cell_anchor in self .cell_anchors ]
77
+ return [cell_anchor .to (dtype = dtype , device = device ) for cell_anchor in self .cell_anchors ]
78
78
79
79
def num_anchors_per_location (self ) -> list [int ]:
80
80
return [len (s ) * len (a ) for s , a in zip (self .sizes , self .aspect_ratios )]
81
81
82
82
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
83
83
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
84
- def grid_anchors (self , grid_sizes : list [list [int ]], strides : list [list [Tensor ]]) -> list [Tensor ]:
84
+ def grid_anchors (
85
+ self ,
86
+ grid_sizes : list [list [int ]],
87
+ strides : list [list [Tensor ]],
88
+ cell_anchors : list [torch .Tensor ],
89
+ ) -> list [Tensor ]:
85
90
anchors = []
86
- cell_anchors = self .cell_anchors
87
91
torch ._assert (cell_anchors is not None , "cell_anchors should not be None" )
88
92
torch ._assert (
89
93
len (grid_sizes ) == len (strides ) == len (cell_anchors ),
@@ -123,8 +127,8 @@ def forward(self, image_list: ImageList, feature_maps: list[Tensor]) -> list[Ten
123
127
]
124
128
for g in grid_sizes
125
129
]
126
- self .set_cell_anchors (dtype , device )
127
- anchors_over_all_feature_maps = self .grid_anchors (grid_sizes , strides )
130
+ cell_anchors = self .set_cell_anchors (dtype , device )
131
+ anchors_over_all_feature_maps = self .grid_anchors (grid_sizes , strides , cell_anchors )
128
132
anchors : list [list [torch .Tensor ]] = []
129
133
for _ in range (len (image_list .image_sizes )):
130
134
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps ]
0 commit comments