Skip to content

Commit bbff5b5

Browse files
committed
Revert "Revert "Minibench refactor (#10376)" (#10405)"
This reverts commit 98c2c53.
1 parent d31ef13 commit bbff5b5

File tree

10 files changed

+283
-240
lines changed

10 files changed

+283
-240
lines changed

extension/benchmark/android/benchmark/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ adb push tokenizer.bin /data/local/tmp/minibench
4343

4444
### Generic model
4545
```
46-
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
46+
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
4747
--es model_dir /data/local/tmp/minibench
4848
```
4949

5050
### LLM
5151
```
52-
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
52+
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
5353
--es model_dir /data/local/tmp/minibench --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin
5454
```
5555

extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2

+2-2
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,11 @@ phases:
114114
adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180
115115

116116
if [ -n "$BIN_FOUND" ]; then
117-
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
117+
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
118118
--es "model_dir" "/data/local/tmp/minibench" \
119119
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin"
120120
elif [ -n "$MODEL_FOUND" ]; then
121-
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
121+
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
122122
--es "model_dir" "/data/local/tmp/minibench" \
123123
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.model"
124124
else

extension/benchmark/android/benchmark/app/build.gradle.kts

+9-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
plugins { id("com.android.application") }
9+
plugins { id("com.android.application")
10+
id("org.jetbrains.kotlin.android")
11+
}
1012

1113
android {
1214
namespace = "org.pytorch.minibench"
@@ -29,8 +31,11 @@ android {
2931
}
3032
}
3133
compileOptions {
32-
sourceCompatibility = JavaVersion.VERSION_1_8
33-
targetCompatibility = JavaVersion.VERSION_1_8
34+
sourceCompatibility = JavaVersion.VERSION_17
35+
targetCompatibility = JavaVersion.VERSION_17
36+
}
37+
kotlinOptions {
38+
jvmTarget = "17"
3439
}
3540
}
3641

@@ -40,6 +45,7 @@ dependencies {
4045
implementation("com.facebook.fbjni:fbjni:0.5.1")
4146
implementation("com.google.code.gson:gson:2.8.6")
4247
implementation("org.json:json:20250107")
48+
implementation("androidx.core:core-ktx:1.13.1")
4349
testImplementation("junit:junit:4.13.2")
4450
androidTestImplementation("androidx.test.ext:junit:1.2.1")
4551
androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")

extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml

-8
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@
2121
</intent-filter>
2222
</activity>
2323

24-
<activity
25-
android:name=".LlmBenchmarkActivity"
26-
android:exported="true">
27-
<intent-filter>
28-
<action android:name="org.pytorch.minibench.BENCHMARK" />
29-
</intent-filter>
30-
</activity>
31-
3224
</application>
3325

3426
</manifest>

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java

+80-85
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010

1111
import android.app.Activity;
1212
import android.content.Intent;
13-
import android.os.AsyncTask;
1413
import android.os.Bundle;
15-
import android.os.Debug;
14+
import android.os.Handler;
15+
import android.os.HandlerThread;
16+
import android.os.Looper;
1617
import android.system.ErrnoException;
1718
import android.system.Os;
1819
import com.google.gson.Gson;
@@ -21,12 +22,22 @@
2122
import java.io.IOException;
2223
import java.util.ArrayList;
2324
import java.util.Arrays;
24-
import java.util.Collections;
2525
import java.util.List;
26-
import java.util.stream.Collectors;
27-
import org.pytorch.executorch.Module;
2826

2927
public class BenchmarkActivity extends Activity {
28+
29+
File mModel;
30+
int mNumIter;
31+
int mNumWarmupIter;
32+
String mTokenizerPath;
33+
float mTemperature;
34+
String mPrompt;
35+
36+
HandlerThread mHandlerThread;
37+
BenchmarkHandler mHandler;
38+
39+
List<BenchmarkMetric> mResult;
40+
3041
@Override
3142
protected void onCreate(Bundle savedInstanceState) {
3243
super.onCreate(savedInstanceState);
@@ -47,95 +58,79 @@ protected void onCreate(Bundle savedInstanceState) {
4758

4859
int numIter = intent.getIntExtra("num_iter", 50);
4960
int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
61+
String tokenizerPath = intent.getStringExtra("tokenizer_path");
62+
float temperature = intent.getFloatExtra("temperature", 0.8f);
63+
String prompt = intent.getStringExtra("prompt");
64+
65+
mModel = model;
66+
mNumIter = numIter;
67+
mNumWarmupIter = numWarmupIter;
68+
mTokenizerPath = tokenizerPath;
69+
mTemperature = temperature;
70+
mPrompt = prompt;
71+
if (mPrompt == null) {
72+
mPrompt = "The ultimate answer";
73+
}
74+
mResult = new ArrayList<>();
5075

51-
long pssIdle = Debug.getPss();
76+
mHandlerThread = new HandlerThread("ModelRunner");
77+
mHandlerThread.start();
78+
mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this);
5279

53-
// TODO: Format the string with a parsable format
54-
Stats stats = new Stats();
80+
mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK);
81+
}
5582

56-
new AsyncTask<Void, Void, Void>() {
57-
@Override
58-
protected Void doInBackground(Void... voids) {
83+
void writeResult() {
84+
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
85+
Gson gson = new Gson();
86+
writer.write(gson.toJson(mResult));
87+
} catch (IOException e) {
88+
e.printStackTrace();
89+
} finally {
90+
finish();
91+
}
92+
}
93+
}
5994

60-
// Record the time it takes to load the model and the forward method
61-
stats.loadStart = System.nanoTime();
62-
Module module = Module.load(model.getPath());
63-
stats.errorCode = module.loadMethod("forward");
64-
stats.loadEnd = System.nanoTime();
95+
class BenchmarkHandler extends Handler {
96+
public static int MESSAGE_RUN_BENCHMARK = 1;
97+
public static int MESSAGE_LLM_RUN_BENCHMARK = 2;
6598

66-
for (int i = 0; i < numWarmupIter; i++) {
67-
module.forward();
68-
}
99+
ModelRunner mModelRunner;
100+
BenchmarkActivity mBenchmarkActivity;
69101

70-
for (int i = 0; i < numIter; i++) {
71-
long start = System.nanoTime();
72-
module.forward();
73-
double forwardMs = (System.nanoTime() - start) * 1e-6;
74-
stats.latency.add(forwardMs);
75-
}
76-
return null;
77-
}
102+
LlmModelRunner mLlmModelRunner;
103+
LlmBenchmark mLlmBenchmark;
78104

79-
@Override
80-
protected void onPostExecute(Void aVoid) {
81-
82-
final BenchmarkMetric.BenchmarkModel benchmarkModel =
83-
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
84-
final List<BenchmarkMetric> results = new ArrayList<>();
85-
// The list of metrics we have atm includes:
86-
// Avg inference latency after N iterations
87-
// Currently the result has large variance from outliers, so only use
88-
// 80% samples in the middle (trimmean 0.2)
89-
Collections.sort(stats.latency);
90-
int resultSize = stats.latency.size();
91-
List<Double> usedLatencyResults =
92-
stats.latency.subList(resultSize / 10, resultSize * 9 / 10);
93-
94-
results.add(
95-
new BenchmarkMetric(
96-
benchmarkModel,
97-
"avg_inference_latency(ms)",
98-
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
99-
0.0f));
100-
results.add(
101-
new BenchmarkMetric(
102-
benchmarkModel,
103-
"trimmean_inference_latency(ms)",
104-
usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
105-
0.0f));
106-
// Model load time
107-
results.add(
108-
new BenchmarkMetric(
109-
benchmarkModel,
110-
"model_load_time(ms)",
111-
(stats.loadEnd - stats.loadStart) * 1e-6,
112-
0.0f));
113-
// Load status
114-
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
115-
// RAM PSS usage
116-
results.add(
117-
new BenchmarkMetric(
118-
benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));
119-
120-
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
121-
Gson gson = new Gson();
122-
writer.write(gson.toJson(results));
123-
} catch (IOException e) {
124-
e.printStackTrace();
125-
}
126-
}
127-
}.execute();
105+
public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) {
106+
super(looper);
107+
mModelRunner = new ModelRunner();
108+
mBenchmarkActivity = benchmarkActivity;
128109
}
129-
}
130-
131-
class Stats {
132-
long loadStart;
133-
long loadEnd;
134-
List<Double> latency = new ArrayList<>();
135-
int errorCode = 0;
136110

137111
@Override
138-
public String toString() {
139-
return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(""));
112+
public void handleMessage(android.os.Message msg) {
113+
if (msg.what == MESSAGE_RUN_BENCHMARK) {
114+
mModelRunner.runBenchmark(
115+
mBenchmarkActivity.mModel,
116+
mBenchmarkActivity.mNumWarmupIter,
117+
mBenchmarkActivity.mNumIter,
118+
mBenchmarkActivity.mResult);
119+
120+
if (mBenchmarkActivity.mTokenizerPath == null) {
121+
mBenchmarkActivity.writeResult();
122+
} else {
123+
this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK);
124+
}
125+
} else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) {
126+
mLlmBenchmark =
127+
new LlmBenchmark(
128+
mBenchmarkActivity,
129+
mBenchmarkActivity.mModel.getPath(),
130+
mBenchmarkActivity.mTokenizerPath,
131+
mBenchmarkActivity.mPrompt,
132+
mBenchmarkActivity.mTemperature,
133+
mBenchmarkActivity.mResult);
134+
}
140135
}
141136
}

0 commit comments

Comments
 (0)