2
2
import os
3
3
import random
4
4
import unittest
5
+ import urllib
5
6
6
7
import numpy as np
7
8
import requests
23
24
from torch .export import export , ExportedProgram
24
25
from torch .utils ._pytree import tree_flatten
25
26
26
- os .environ ["https_proxy" ] = "http://fwdproxy:8080"
27
+ proxies = {
28
+ "http" : "http://fwdproxy:8080" ,
29
+ "https" : "http://fwdproxy:8080" ,
30
+ }
27
31
28
32
29
33
def compute_sqnr (x : torch .Tensor , y : torch .Tensor ) -> float :
@@ -38,7 +42,12 @@ def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
38
42
39
43
40
44
def read_mp3_from_url (url ):
41
- response = requests .get (url )
45
+ try :
46
+ response = requests .get (url )
47
+ except :
48
+ # FB-only hack, need to use a forwarding proxy to get url
49
+ response = requests .get (url , proxies = proxies )
50
+
42
51
response .raise_for_status () # Ensure request is successful
43
52
audio_stream = io .BytesIO (response .content )
44
53
waveform , sample_rate = torchaudio .load (audio_stream , format = "mp3" )
@@ -68,7 +77,13 @@ def seed_all(seed):
68
77
seed_all (42424242 )
69
78
70
79
if mimi_weight is None :
71
- mimi_weight = hf_hub_download (hf_repo , loaders .MIMI_NAME )
80
+ try :
81
+ mimi_weight = hf_hub_download (hf_repo , loaders .MIMI_NAME )
82
+ except :
83
+ mimi_weight = hf_hub_download (
84
+ hf_repo , loaders .MIMI_NAME , proxies = proxies
85
+ )
86
+
72
87
cls .mimi = loaders .get_mimi (mimi_weight , device )
73
88
cls .device = device
74
89
cls .sample_pcm , cls .sample_sr = read_mp3_from_url (
0 commit comments