Skip to content

Revert "Minibench refactor (#10376)" #10405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit on Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions extension/benchmark/android/benchmark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ adb push tokenizer.bin /data/local/tmp/minibench

### Generic model
```
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
--es model_dir /data/local/tmp/minibench
```

### LLM
```
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
--es model_dir /data/local/tmp/minibench --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ phases:
adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180

if [ -n "$BIN_FOUND" ]; then
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
--es "model_dir" "/data/local/tmp/minibench" \
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin"
elif [ -n "$MODEL_FOUND" ]; then
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
--es "model_dir" "/data/local/tmp/minibench" \
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.model"
else
Expand Down
12 changes: 3 additions & 9 deletions extension/benchmark/android/benchmark/app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

plugins { id("com.android.application")
id("org.jetbrains.kotlin.android")
}
plugins { id("com.android.application") }

android {
namespace = "org.pytorch.minibench"
Expand All @@ -31,11 +29,8 @@ android {
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_17
targetCompatibility = JavaVersion.VERSION_17
}
kotlinOptions {
jvmTarget = "17"
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8
}
}

Expand All @@ -45,7 +40,6 @@ dependencies {
implementation("com.facebook.fbjni:fbjni:0.5.1")
implementation("com.google.code.gson:gson:2.8.6")
implementation("org.json:json:20250107")
implementation("androidx.core:core-ktx:1.13.1")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.2.1")
androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@
</intent-filter>
</activity>

<activity
android:name=".LlmBenchmarkActivity"
android:exported="true">
<intent-filter>
<action android:name="org.pytorch.minibench.BENCHMARK" />
</intent-filter>
</activity>

</application>

</manifest>
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@

import android.app.Activity;
import android.content.Intent;
import android.os.AsyncTask;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.Looper;
import android.os.Debug;
import android.system.ErrnoException;
import android.system.Os;
import com.google.gson.Gson;
Expand All @@ -22,22 +21,12 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.pytorch.executorch.Module;

public class BenchmarkActivity extends Activity {

File mModel;
int mNumIter;
int mNumWarmupIter;
String mTokenizerPath;
float mTemperature;
String mPrompt;

HandlerThread mHandlerThread;
BenchmarkHandler mHandler;

List<BenchmarkMetric> mResult;

@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
Expand All @@ -58,79 +47,95 @@ protected void onCreate(Bundle savedInstanceState) {

int numIter = intent.getIntExtra("num_iter", 50);
int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
String tokenizerPath = intent.getStringExtra("tokenizer_path");
float temperature = intent.getFloatExtra("temperature", 0.8f);
String prompt = intent.getStringExtra("prompt");

mModel = model;
mNumIter = numIter;
mNumWarmupIter = numWarmupIter;
mTokenizerPath = tokenizerPath;
mTemperature = temperature;
mPrompt = prompt;
if (mPrompt == null) {
mPrompt = "The ultimate answer";
}
mResult = new ArrayList<>();

mHandlerThread = new HandlerThread("ModelRunner");
mHandlerThread.start();
mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this);
long pssIdle = Debug.getPss();

mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK);
}
// TODO: Format the string with a parsable format
Stats stats = new Stats();

void writeResult() {
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(mResult));
} catch (IOException e) {
e.printStackTrace();
} finally {
finish();
}
}
}
new AsyncTask<Void, Void, Void>() {
@Override
protected Void doInBackground(Void... voids) {

class BenchmarkHandler extends Handler {
public static int MESSAGE_RUN_BENCHMARK = 1;
public static int MESSAGE_LLM_RUN_BENCHMARK = 2;
// Record the time it takes to load the model and the forward method
stats.loadStart = System.nanoTime();
Module module = Module.load(model.getPath());
stats.errorCode = module.loadMethod("forward");
stats.loadEnd = System.nanoTime();

ModelRunner mModelRunner;
BenchmarkActivity mBenchmarkActivity;
for (int i = 0; i < numWarmupIter; i++) {
module.forward();
}

LlmModelRunner mLlmModelRunner;
LlmBenchmark mLlmBenchmark;
for (int i = 0; i < numIter; i++) {
long start = System.nanoTime();
module.forward();
double forwardMs = (System.nanoTime() - start) * 1e-6;
stats.latency.add(forwardMs);
}
return null;
}

public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) {
super(looper);
mModelRunner = new ModelRunner();
mBenchmarkActivity = benchmarkActivity;
@Override
protected void onPostExecute(Void aVoid) {

final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Avg inference latency after N iterations
// Currently the result has large variance from outliers, so only use
// 80% samples in the middle (trimmean 0.2)
Collections.sort(stats.latency);
int resultSize = stats.latency.size();
List<Double> usedLatencyResults =
stats.latency.subList(resultSize / 10, resultSize * 9 / 10);

results.add(
new BenchmarkMetric(
benchmarkModel,
"avg_inference_latency(ms)",
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
0.0f));
results.add(
new BenchmarkMetric(
benchmarkModel,
"trimmean_inference_latency(ms)",
usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
0.0f));
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel,
"model_load_time(ms)",
(stats.loadEnd - stats.loadStart) * 1e-6,
0.0f));
// Load status
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
// RAM PSS usage
results.add(
new BenchmarkMetric(
benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));

try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
}
}.execute();
}
}

class Stats {
long loadStart;
long loadEnd;
List<Double> latency = new ArrayList<>();
int errorCode = 0;

@Override
public void handleMessage(android.os.Message msg) {
if (msg.what == MESSAGE_RUN_BENCHMARK) {
mModelRunner.runBenchmark(
mBenchmarkActivity.mModel,
mBenchmarkActivity.mNumWarmupIter,
mBenchmarkActivity.mNumIter,
mBenchmarkActivity.mResult);

if (mBenchmarkActivity.mTokenizerPath == null) {
mBenchmarkActivity.writeResult();
} else {
this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK);
}
} else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) {
mLlmBenchmark =
new LlmBenchmark(
mBenchmarkActivity,
mBenchmarkActivity.mModel.getPath(),
mBenchmarkActivity.mTokenizerPath,
mBenchmarkActivity.mPrompt,
mBenchmarkActivity.mTemperature,
mBenchmarkActivity.mResult);
}
public String toString() {
return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(""));
}
}
Loading
Loading