@@ -14,9 +14,10 @@ Requirements: - python >= 3.7
We highly recommend CUDA when using TorchRec. If using CUDA: - cuda >=
11.0

+ .. Should these be updated?
.. code:: python

-    # install conda to make installing pytorch with cudatoolkit 11.3 easier.
+    # install conda to make installing pytorch with cudatoolkit 11.3 easier.
    !sudo rm Miniconda3-py37_4.9.2-Linux-x86_64.sh Miniconda3-py37_4.9.2-Linux-x86_64.sh.*
    !sudo wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh
    !sudo chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh
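As a quick sanity check for the CUDA requirement above, a snippet along these lines can confirm which toolkit the installed PyTorch build was compiled against. This is a minimal sketch, assuming only a working PyTorch install; it is not part of the change being reviewed:

.. code:: python

    import torch

    # Report the PyTorch build and the CUDA toolkit it was compiled against;
    # the tutorial expects a toolkit version of at least 11.0.
    print(torch.__version__)
    print(torch.version.cuda)

    # True only if a GPU is actually visible to this process.
    print(torch.cuda.is_available())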
@@ -209,7 +210,7 @@ embedding table placement using planner and generate sharded model using
    )
    sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
    plan: ShardingPlan = planner.collective_plan(module, sharders, pg)
-
+
    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
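For readers skimming only this hunk, the lines above sit in the middle of the tutorial's sharding flow: build a planner, produce a collective plan, then wrap the module in DistributedModelParallel. A condensed sketch of that flow follows; it assumes `module`, `pg`, `device`, and `world_size` are already defined as in the surrounding tutorial code, and exact import paths or constructor arguments may differ between TorchRec versions:

.. code:: python

    from typing import cast

    import torch
    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
    from torchrec.distributed.model_parallel import DistributedModelParallel
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan

    # Assumed to be defined earlier in the tutorial: module (an
    # EmbeddingBagCollection), pg (the process group), device, world_size.
    planner = EmbeddingShardingPlanner(
        topology=Topology(world_size=world_size, compute_device=device.type),
    )
    sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]

    # collective_plan is a collective call: every rank must run it so that all
    # ranks agree on the same table placement.
    plan: ShardingPlan = planner.collective_plan(module, sharders, pg)

    # DistributedModelParallel applies the plan, replacing the embedding tables
    # with their sharded counterparts on this rank.
    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )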
@@ -230,7 +231,7 @@ ranks.
.. code:: python

    import multiprocess
-
+
    def spmd_sharing_simulation(
        sharding_type: ShardingType = ShardingType.TABLE_WISE,
        world_size=2,
@@ -250,7 +251,7 @@ ranks.
        )
        p.start()
        processes.append(p)
-
+
    for p in processes:
        p.join()
        assert 0 == p.exitcode
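The two hunks above both touch the tutorial's `spmd_sharing_simulation` helper, which launches one process per simulated rank and then joins them all. A self-contained sketch of that spawn-and-join pattern follows; `_worker` and `spmd_sharing_simulation_sketch` are illustrative stand-ins rather than names from the tutorial:

.. code:: python

    import multiprocess

    def _worker(rank: int, world_size: int) -> None:
        # Stand-in for the tutorial's single-rank body: each rank would set up
        # its process group, plan and shard the model, and run a forward pass.
        print(f"rank {rank} of {world_size} finished")

    def spmd_sharing_simulation_sketch(world_size: int = 2) -> None:
        # "spawn" gives every rank a fresh interpreter, which CUDA and
        # torch.distributed expect when launching from a notebook.
        ctx = multiprocess.get_context("spawn")
        processes = []
        for rank in range(world_size):
            p = ctx.Process(target=_worker, args=(rank, world_size))
            p.start()
            processes.append(p)

        # Wait for every rank and fail loudly if any of them crashed.
        for p in processes:
            p.join()
            assert 0 == p.exitcode

    if __name__ == "__main__":
        spmd_sharing_simulation_sketch()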
@@ -329,4 +330,3 @@ With data parallel, we will repeat the tables for all devices.
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}
-
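A plan like the one printed above is typically obtained by constraining every table to data-parallel sharding before planning. A hedged sketch of how such constraints are usually expressed is below; the table names are taken from the output above, and the import paths may vary by TorchRec version:

.. code:: python

    from torchrec.distributed.planner.types import ParameterConstraints
    from torchrec.distributed.types import ShardingType

    # Force every table onto the data-parallel path, so each rank keeps a full
    # replica, matching the plan shown in the output above.
    constraints = {
        table_name: ParameterConstraints(
            sharding_types=[ShardingType.DATA_PARALLEL.value],
        )
        for table_name in [
            "large_table_0",
            "large_table_1",
            "small_table_0",
            "small_table_1",
        ]
    }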