Skip to content

Commit 897c165

Browse files
committed
Add on-device countTokens implementation to ChromeAdapter
1 parent 02c4ba7 commit 897c165

File tree

5 files changed

+103
-6
lines changed

5 files changed

+103
-6
lines changed

packages/vertexai/src/methods/chrome-adapter.test.ts

+52
Original file line numberDiff line numberDiff line change
@@ -307,4 +307,56 @@ describe('ChromeAdapter', () => {
307307
});
308308
});
309309
});
310+
describe('countTokens', () => {
  // Verifies that ChromeAdapter.countTokens creates a session with the
  // configured on-device params, forwards the request in the on-device
  // message shape, and reports the measured usage as `totalTokens`.
  it('counts tokens from a singular input', async () => {
    const text = 'first';
    const stubbedTokenCount = 10;
    const createOptions = {
      systemPrompt: 'be yourself'
    } as LanguageModelCreateOptions;

    // Fake provider: `create` is replaced by a stub further down so the
    // arguments it receives can be inspected.
    const provider = {
      create: () => Promise.resolve({})
    } as LanguageModel;
    // Fake session: the inline implementation (123) is a placeholder that
    // the stub below overrides with `stubbedTokenCount`.
    const session = {
      measureInputUsage: _input => Promise.resolve(123)
    } as LanguageModel;

    const createStub = stub(provider, 'create');
    createStub.resolves(session);
    const measureStub = stub(session, 'measureInputUsage');
    measureStub.resolves(stubbedTokenCount);

    const adapter = new ChromeAdapter(
      provider,
      'prefer_on_device',
      createOptions
    );
    const request = {
      contents: [{ role: 'user', parts: [{ text }] }]
    } as GenerateContentRequest;

    const response = await adapter.countTokens(request);

    // The on-device params given to the adapter must reach `create` as-is.
    expect(createStub).to.have.been.calledOnceWith(createOptions);

    // The Vertex-style content must be converted to the on-device message
    // shape before it is measured.
    const expectedMessages = [
      {
        role: 'user',
        content: [
          {
            type: 'text',
            content: text
          }
        ]
      }
    ];
    expect(measureStub).to.have.been.calledOnceWith(expectedMessages);

    // The measured usage is surfaced through the response's JSON payload.
    expect(await response.json()).to.deep.equal({
      totalTokens: stubbedTokenCount
    });
  });
});
310362
});

packages/vertexai/src/methods/chrome-adapter.ts

+16
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import {
1919
Content,
20+
CountTokensRequest,
2021
GenerateContentRequest,
2122
InferenceMode,
2223
Part,
@@ -117,6 +118,21 @@ export class ChromeAdapter {
117118
} as Response;
118119
}
119120

121+
/**
 * Counts the tokens in the given request using an on-device session.
 *
 * A session is created from `this.onDeviceParams` (or an empty options
 * object when none are set), `request.contents` is converted with
 * `ChromeAdapter.toLanguageModelMessages`, and the converted messages are
 * passed to the session's `measureInputUsage`.
 *
 * @param request - The count-tokens request whose `contents` are measured.
 * @returns An object cast to `Response` that only provides `json()`, which
 * resolves to `{ totalTokens }`.
 */
async countTokens(request: CountTokensRequest): Promise<Response> {
  // TODO: Check if the request contains an image, and if so, throw.
  // TODO: normalize on-device params during construction.
  const sessionOptions = this.onDeviceParams || {};
  const session = await this.createSession(sessionOptions);

  const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
  const totalTokens = await session.measureInputUsage(messages);

  // Only `json()` is implemented; callers are expected to read the body
  // through it.
  const countResponse = {
    json: async () => ({ totalTokens })
  };
  return countResponse as Response;
}
135+
120136
/**
121137
* Asserts inference for the given request can be performed by an on-device model.
122138
*/

packages/vertexai/src/methods/count-tokens.test.ts

+13-4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import { countTokens } from './count-tokens';
2525
import { CountTokensRequest } from '../types';
2626
import { ApiSettings } from '../types/internal';
2727
import { Task } from '../requests/request';
28+
import { ChromeAdapter } from './chrome-adapter';
2829

2930
use(sinonChai);
3031
use(chaiAsPromised);
@@ -55,7 +56,8 @@ describe('countTokens()', () => {
5556
const result = await countTokens(
5657
fakeApiSettings,
5758
'model',
58-
fakeRequestParams
59+
fakeRequestParams,
60+
new ChromeAdapter()
5961
);
6062
expect(result.totalTokens).to.equal(6);
6163
expect(result.totalBillableCharacters).to.equal(16);
@@ -81,7 +83,8 @@ describe('countTokens()', () => {
8183
const result = await countTokens(
8284
fakeApiSettings,
8385
'model',
84-
fakeRequestParams
86+
fakeRequestParams,
87+
new ChromeAdapter()
8588
);
8689
expect(result.totalTokens).to.equal(1837);
8790
expect(result.totalBillableCharacters).to.equal(117);
@@ -109,7 +112,8 @@ describe('countTokens()', () => {
109112
const result = await countTokens(
110113
fakeApiSettings,
111114
'model',
112-
fakeRequestParams
115+
fakeRequestParams,
116+
new ChromeAdapter()
113117
);
114118
expect(result.totalTokens).to.equal(258);
115119
expect(result).to.not.have.property('totalBillableCharacters');
@@ -135,7 +139,12 @@ describe('countTokens()', () => {
135139
json: mockResponse.json
136140
} as Response);
137141
await expect(
138-
countTokens(fakeApiSettings, 'model', fakeRequestParams)
142+
countTokens(
143+
fakeApiSettings,
144+
'model',
145+
fakeRequestParams,
146+
new ChromeAdapter()
147+
)
139148
).to.be.rejectedWith(/404.*not found/);
140149
expect(mockFetch).to.be.called;
141150
});

packages/vertexai/src/methods/count-tokens.ts

+16-1
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ import {
2222
} from '../types';
2323
import { Task, makeRequest } from '../requests/request';
2424
import { ApiSettings } from '../types/internal';
25+
import { ChromeAdapter } from './chrome-adapter';
2526

26-
export async function countTokens(
27+
export async function countTokensOnCloud(
2728
apiSettings: ApiSettings,
2829
model: string,
2930
params: CountTokensRequest,
@@ -39,3 +40,17 @@ export async function countTokens(
3940
);
4041
return response.json();
4142
}
43+
44+
/**
 * Counts tokens for `params`, using the on-device adapter when it reports
 * itself available for the request and `countTokensOnCloud` otherwise.
 *
 * @param apiSettings - Settings forwarded to the cloud request.
 * @param model - Model name forwarded to the cloud request.
 * @param params - The count-tokens request.
 * @param chromeAdapter - Adapter asked whether on-device inference is
 * available, and used to count tokens when it is.
 * @param requestOptions - Optional request options, used only on the cloud
 * path.
 * @returns The count-tokens response from whichever path handled the request.
 */
export async function countTokens(
  apiSettings: ApiSettings,
  model: string,
  params: CountTokensRequest,
  chromeAdapter: ChromeAdapter,
  requestOptions?: RequestOptions
): Promise<CountTokensResponse> {
  const canRunOnDevice = await chromeAdapter.isAvailable(params);
  if (!canRunOnDevice) {
    return countTokensOnCloud(apiSettings, model, params, requestOptions);
  }

  const onDeviceResponse = await chromeAdapter.countTokens(params);
  return onDeviceResponse.json();
}

packages/vertexai/src/models/generative-model.ts

+6-1
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ export class GenerativeModel extends VertexAIModel {
153153
request: CountTokensRequest | string | Array<string | Part>
154154
): Promise<CountTokensResponse> {
155155
const formattedParams = formatGenerateContentInput(request);
156-
return countTokens(this._apiSettings, this.model, formattedParams);
156+
return countTokens(
157+
this._apiSettings,
158+
this.model,
159+
formattedParams,
160+
this.chromeAdapter
161+
);
157162
}
158163
}

0 commit comments

Comments
 (0)