10
10
11
11
import android .app .Activity ;
12
12
import android .content .Intent ;
13
- import android .os .AsyncTask ;
14
13
import android .os .Bundle ;
15
- import android .os .Debug ;
14
+ import android .os .Handler ;
15
+ import android .os .HandlerThread ;
16
+ import android .os .Looper ;
16
17
import android .system .ErrnoException ;
17
18
import android .system .Os ;
18
19
import com .google .gson .Gson ;
21
22
import java .io .IOException ;
22
23
import java .util .ArrayList ;
23
24
import java .util .Arrays ;
24
- import java .util .Collections ;
25
25
import java .util .List ;
26
- import java .util .stream .Collectors ;
27
- import org .pytorch .executorch .Module ;
28
26
29
27
public class BenchmarkActivity extends Activity {
28
+
29
+ File mModel ;
30
+ int mNumIter ;
31
+ int mNumWarmupIter ;
32
+ String mTokenizerPath ;
33
+ float mTemperature ;
34
+ String mPrompt ;
35
+
36
+ HandlerThread mHandlerThread ;
37
+ BenchmarkHandler mHandler ;
38
+
39
+ List <BenchmarkMetric > mResult ;
40
+
30
41
@ Override
31
42
protected void onCreate (Bundle savedInstanceState ) {
32
43
super .onCreate (savedInstanceState );
@@ -47,95 +58,79 @@ protected void onCreate(Bundle savedInstanceState) {
47
58
48
59
int numIter = intent .getIntExtra ("num_iter" , 50 );
49
60
int numWarmupIter = intent .getIntExtra ("num_warm_up_iter" , 10 );
61
+ String tokenizerPath = intent .getStringExtra ("tokenizer_path" );
62
+ float temperature = intent .getFloatExtra ("temperature" , 0.8f );
63
+ String prompt = intent .getStringExtra ("prompt" );
64
+
65
+ mModel = model ;
66
+ mNumIter = numIter ;
67
+ mNumWarmupIter = numWarmupIter ;
68
+ mTokenizerPath = tokenizerPath ;
69
+ mTemperature = temperature ;
70
+ mPrompt = prompt ;
71
+ if (mPrompt == null ) {
72
+ mPrompt = "The ultimate answer" ;
73
+ }
74
+ mResult = new ArrayList <>();
50
75
51
- long pssIdle = Debug .getPss ();
76
+ mHandlerThread = new HandlerThread ("ModelRunner" );
77
+ mHandlerThread .start ();
78
+ mHandler = new BenchmarkHandler (mHandlerThread .getLooper (), this );
52
79
53
- // TODO: Format the string with a parsable format
54
- Stats stats = new Stats ();
80
+ mHandler . sendEmptyMessage ( BenchmarkHandler . MESSAGE_RUN_BENCHMARK );
81
+ }
55
82
56
- new AsyncTask <Void , Void , Void >() {
57
- @ Override
58
- protected Void doInBackground (Void ... voids ) {
83
+ void writeResult () {
84
+ try (FileWriter writer = new FileWriter (getFilesDir () + "/benchmark_results.json" )) {
85
+ Gson gson = new Gson ();
86
+ writer .write (gson .toJson (mResult ));
87
+ } catch (IOException e ) {
88
+ e .printStackTrace ();
89
+ } finally {
90
+ finish ();
91
+ }
92
+ }
93
+ }
59
94
60
- // Record the time it takes to load the model and the forward method
61
- stats .loadStart = System .nanoTime ();
62
- Module module = Module .load (model .getPath ());
63
- stats .errorCode = module .loadMethod ("forward" );
64
- stats .loadEnd = System .nanoTime ();
95
+ class BenchmarkHandler extends Handler {
96
+ public static int MESSAGE_RUN_BENCHMARK = 1 ;
97
+ public static int MESSAGE_LLM_RUN_BENCHMARK = 2 ;
65
98
66
- for (int i = 0 ; i < numWarmupIter ; i ++) {
67
- module .forward ();
68
- }
99
+ ModelRunner mModelRunner ;
100
+ BenchmarkActivity mBenchmarkActivity ;
69
101
70
- for (int i = 0 ; i < numIter ; i ++) {
71
- long start = System .nanoTime ();
72
- module .forward ();
73
- double forwardMs = (System .nanoTime () - start ) * 1e-6 ;
74
- stats .latency .add (forwardMs );
75
- }
76
- return null ;
77
- }
102
+ LlmModelRunner mLlmModelRunner ;
103
+ LlmBenchmark mLlmBenchmark ;
78
104
79
- @ Override
80
- protected void onPostExecute (Void aVoid ) {
81
-
82
- final BenchmarkMetric .BenchmarkModel benchmarkModel =
83
- BenchmarkMetric .extractBackendAndQuantization (model .getName ().replace (".pte" , "" ));
84
- final List <BenchmarkMetric > results = new ArrayList <>();
85
- // The list of metrics we have atm includes:
86
- // Avg inference latency after N iterations
87
- // Currently the result has large variance from outliers, so only use
88
- // 80% samples in the middle (trimmean 0.2)
89
- Collections .sort (stats .latency );
90
- int resultSize = stats .latency .size ();
91
- List <Double > usedLatencyResults =
92
- stats .latency .subList (resultSize / 10 , resultSize * 9 / 10 );
93
-
94
- results .add (
95
- new BenchmarkMetric (
96
- benchmarkModel ,
97
- "avg_inference_latency(ms)" ,
98
- stats .latency .stream ().mapToDouble (l -> l ).average ().orElse (0.0f ),
99
- 0.0f ));
100
- results .add (
101
- new BenchmarkMetric (
102
- benchmarkModel ,
103
- "trimmean_inference_latency(ms)" ,
104
- usedLatencyResults .stream ().mapToDouble (l -> l ).average ().orElse (0.0f ),
105
- 0.0f ));
106
- // Model load time
107
- results .add (
108
- new BenchmarkMetric (
109
- benchmarkModel ,
110
- "model_load_time(ms)" ,
111
- (stats .loadEnd - stats .loadStart ) * 1e-6 ,
112
- 0.0f ));
113
- // Load status
114
- results .add (new BenchmarkMetric (benchmarkModel , "load_status" , stats .errorCode , 0 ));
115
- // RAM PSS usage
116
- results .add (
117
- new BenchmarkMetric (
118
- benchmarkModel , "ram_pss_usage(mb)" , (Debug .getPss () - pssIdle ) / 1024 , 0 ));
119
-
120
- try (FileWriter writer = new FileWriter (getFilesDir () + "/benchmark_results.json" )) {
121
- Gson gson = new Gson ();
122
- writer .write (gson .toJson (results ));
123
- } catch (IOException e ) {
124
- e .printStackTrace ();
125
- }
126
- }
127
- }.execute ();
105
+ public BenchmarkHandler (Looper looper , BenchmarkActivity benchmarkActivity ) {
106
+ super (looper );
107
+ mModelRunner = new ModelRunner ();
108
+ mBenchmarkActivity = benchmarkActivity ;
128
109
}
129
- }
130
-
131
- class Stats {
132
- long loadStart ;
133
- long loadEnd ;
134
- List <Double > latency = new ArrayList <>();
135
- int errorCode = 0 ;
136
110
137
111
@ Override
138
- public String toString () {
139
- return "latency: " + latency .stream ().map (Object ::toString ).collect (Collectors .joining ("" ));
112
+ public void handleMessage (android .os .Message msg ) {
113
+ if (msg .what == MESSAGE_RUN_BENCHMARK ) {
114
+ mModelRunner .runBenchmark (
115
+ mBenchmarkActivity .mModel ,
116
+ mBenchmarkActivity .mNumWarmupIter ,
117
+ mBenchmarkActivity .mNumIter ,
118
+ mBenchmarkActivity .mResult );
119
+
120
+ if (mBenchmarkActivity .mTokenizerPath == null ) {
121
+ mBenchmarkActivity .writeResult ();
122
+ } else {
123
+ this .sendEmptyMessage (MESSAGE_LLM_RUN_BENCHMARK );
124
+ }
125
+ } else if (msg .what == MESSAGE_LLM_RUN_BENCHMARK ) {
126
+ mLlmBenchmark =
127
+ new LlmBenchmark (
128
+ mBenchmarkActivity ,
129
+ mBenchmarkActivity .mModel .getPath (),
130
+ mBenchmarkActivity .mTokenizerPath ,
131
+ mBenchmarkActivity .mPrompt ,
132
+ mBenchmarkActivity .mTemperature ,
133
+ mBenchmarkActivity .mResult );
134
+ }
140
135
}
141
136
}
0 commit comments