Skip to content

Commit ccb6ec0

Browse files
authored
feat(chat): add cancel button in chat to stop generation (#36)
1 parent d1b6672 commit ccb6ec0

File tree

9 files changed

+191
-73
lines changed

9 files changed

+191
-73
lines changed

src/common/chat/cloudChat.ts

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { configuration } from "../utils/configuration";
1111
type Parameters = {
1212
temperature: number;
1313
n_predict: number;
14+
controller?: AbortController;
1415
};
1516

1617
export const sendChatRequestCloud = async (
@@ -45,6 +46,9 @@ export const sendChatRequestCloud = async (
4546
});
4647

4748
const stream = await model.pipe(parser).stream(messages, {
49+
configurable: {
50+
signal: parameters.controller,
51+
},
4852
maxConcurrency: 1,
4953
});
5054

src/common/chat/index.ts

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ export async function* chat(
2626
history: ChatMessage[],
2727
config?: {
2828
provideHighlightedText?: boolean;
29+
abortController: AbortController;
2930
}
3031
) {
3132
const loggerCompletion = logCompletion();
@@ -36,6 +37,7 @@ export async function* chat(
3637
n_predict: 4096,
3738
stop: [],
3839
temperature: 0.7,
40+
controller: config?.abortController,
3941
};
4042

4143
const { stopTask } = statusBar.startTask();

src/common/chat/localChat.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ export async function* sendChatRequestLocal(
5353
const startTime = performance.now();
5454

5555
let timings;
56-
for await (const chunk of llama(prompt, parametersForCompletion, { url })) {
56+
for await (const chunk of llama(prompt, parametersForCompletion, {
57+
url,
58+
controller: parameters.controller,
59+
})) {
5760
// @ts-ignore
5861
if (chunk.data) {
5962
// @ts-ignore

src/common/panel/chat.ts

+54-18
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import { Disposable, Webview, window, Uri } from "vscode";
1+
import { Disposable, Webview, Uri } from "vscode";
22
import * as vscode from "vscode";
33
import { getUri } from "../utils/getUri";
44
import { getNonce } from "../utils/getNonce";
55
import { chat } from "../chat";
6+
import { ChatMessage } from "../prompt/promptChat";
67

78
export type MessageType =
89
| {
@@ -13,15 +14,16 @@ export type MessageType =
1314
}
1415
| {
1516
type: "e2w-response";
17+
id: string;
1618
command: string;
17-
messageId: string;
1819
done: boolean;
1920
data: any;
2021
};
2122

2223
export class ChatPanel implements vscode.WebviewViewProvider {
2324
private disposables: Disposable[] = [];
2425
private webview: Webview | undefined;
26+
private messageCallback: Record<string, any> = {};
2527

2628
constructor(private readonly extensionUri: vscode.Uri) {}
2729

@@ -94,26 +96,19 @@ export class ChatPanel implements vscode.WebviewViewProvider {
9496
private setWebviewMessageListener(webview: Webview) {
9597
webview.onDidReceiveMessage(
9698
async (message: any) => {
97-
const sendResponse = (messageToResponse: any, done: boolean) => {
98-
this.postMessage({
99-
type: "e2w-response",
100-
command: message.type,
101-
messageId: message.messageId,
102-
data: messageToResponse,
103-
done: done,
104-
});
105-
};
99+
if (message.type in this.messageCallback) {
100+
this.messageCallback[message.type]();
101+
return;
102+
}
106103
const type = message.type;
107-
const data = message.data;
108104

109105
switch (type) {
110106
case "sendMessage":
111-
for await (const message of chat(data, {
112-
provideHighlightedText: true,
113-
})) {
114-
sendResponse(message, false);
115-
}
116-
sendResponse("", true);
107+
await this.handleStartGeneration({
108+
chatMessage: message.data,
109+
messageId: message.messageId,
110+
messageType: message.type,
111+
});
117112
return;
118113
}
119114
},
@@ -122,6 +117,47 @@ export class ChatPanel implements vscode.WebviewViewProvider {
122117
);
123118
}
124119

120+
private addMessageListener(
121+
commandOrMessageId: string,
122+
callback: (message: any) => void
123+
) {
124+
this.messageCallback[commandOrMessageId] = callback;
125+
}
126+
127+
private async handleStartGeneration({
128+
messageId,
129+
messageType,
130+
chatMessage,
131+
}: {
132+
messageId: string;
133+
messageType: string;
134+
chatMessage: ChatMessage[];
135+
}) {
136+
const sendResponse = (messageToResponse: any, done: boolean) => {
137+
this.postMessage({
138+
type: "e2w-response",
139+
id: messageId,
140+
command: messageType,
141+
data: messageToResponse,
142+
done: done,
143+
});
144+
};
145+
const abortController = new AbortController();
146+
147+
this.addMessageListener("abort-generate", () => {
148+
abortController.abort();
149+
});
150+
151+
for await (const message of chat(chatMessage, {
152+
provideHighlightedText: true,
153+
abortController,
154+
})) {
155+
sendResponse(message, false);
156+
}
157+
158+
sendResponse("", true);
159+
}
160+
125161
public async sendMessageToWebview(
126162
command: MessageType["command"],
127163
data: MessageType["data"]

webviews/src/App.tsx

+3-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ export const App = () => {
1515
input,
1616
setInput,
1717
startNewChat,
18+
stop,
1819
} = useChat();
1920

2021
useMessageListener("startNewChat", () => {
@@ -50,14 +51,11 @@ export const App = () => {
5051
buttonEnd={
5152
<VSCodeButton
5253
appearance="icon"
53-
disabled={isLoading}
54-
onClick={handleSubmit}
54+
onClick={isLoading ? stop : handleSubmit}
5555
>
5656
<span
5757
className={`codicon ${
58-
isLoading
59-
? "codicon-loading codicon-modifier-spin codicon-modifier-disabled"
60-
: "codicon-send"
58+
isLoading ? "codicon-debug-stop" : "codicon-send"
6159
}`}
6260
></span>
6361
</VSCodeButton>

webviews/src/components/TextArea/index.tsx

+1-2
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ const TextArea = ({
5454
onSubmit();
5555
}
5656

57-
event.preventDefault(); // Prevents the addition of a new line in the text field
57+
event.preventDefault();
5858
}
5959
}}
6060
></textarea>
@@ -63,5 +63,4 @@ const TextArea = ({
6363
);
6464
};
6565

66-
// 24 42 61
6766
export default TextArea;

webviews/src/hooks/useChat.ts

+45-42
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,58 @@
1-
import { useCallback, useState } from "react";
1+
import { useCallback, useRef, useState } from "react";
22
import { randomMessageId } from "../utilities/messageId";
33
import { vscode } from "../utilities/vscode";
44

5+
export type ChatMessage = {
6+
role: string;
7+
content: string;
8+
chatMessageId: string;
9+
};
10+
511
export const useChat = () => {
6-
const [chatMessages, setChatMessages] = useState<
7-
{
8-
role: string;
9-
content: string;
10-
chatMessageId: string;
11-
}[]
12-
>([]);
12+
const [chatMessages, setChatMessages] = useState<ChatMessage[]>([]);
1313

1414
const [input, setInput] = useState("");
1515
const [isLoading, setIsLoading] = useState(false);
1616

17+
const abortController = useRef(new AbortController());
18+
19+
const sendMessage = async (chatHistoryLocal: ChatMessage[]) => {
20+
const messageId = randomMessageId();
21+
for await (const newMessage of vscode.startGeneration(chatHistoryLocal, {
22+
signal: abortController.current.signal,
23+
})) {
24+
setChatMessages((chatHistoryLocal) => {
25+
const messages = chatHistoryLocal.filter(
26+
(message) => message.chatMessageId !== messageId
27+
);
28+
29+
const currentChatMessage = chatHistoryLocal.find(
30+
(message) => message.chatMessageId === messageId
31+
);
32+
33+
return [
34+
...messages,
35+
{
36+
role: "ai",
37+
content: (currentChatMessage?.content || "") + newMessage,
38+
chatMessageId: messageId,
39+
},
40+
];
41+
});
42+
}
43+
setIsLoading(false);
44+
};
45+
1746
const handleSubmit = () => {
1847
if (isLoading) {
1948
return;
2049
}
2150
if (input === "") {
2251
return;
2352
}
53+
if (abortController.current.signal.aborted) {
54+
abortController.current = new AbortController();
55+
}
2456

2557
setChatMessages((value) => {
2658
const messageId = randomMessageId();
@@ -41,40 +73,10 @@ export const useChat = () => {
4173
setInput("");
4274
};
4375

44-
const sendMessage = async (chatHistoryLocal: any) => {
45-
const messageId = randomMessageId();
46-
await vscode.postMessageCallback(
47-
{
48-
type: "sendMessage",
49-
data: chatHistoryLocal,
50-
},
51-
(newMessage) => {
52-
setChatMessages((chatHistoryLocal) => {
53-
const messages = chatHistoryLocal.filter(
54-
(message) => message.chatMessageId !== messageId
55-
);
56-
57-
const currentChatMessage = chatHistoryLocal.find(
58-
(message) => message.chatMessageId === messageId
59-
);
60-
61-
if (newMessage.done) {
62-
setIsLoading(false);
63-
return chatHistoryLocal;
64-
}
65-
66-
return [
67-
...messages,
68-
{
69-
role: "ai",
70-
content: (currentChatMessage?.content || "") + newMessage.data,
71-
chatMessageId: messageId,
72-
},
73-
];
74-
});
75-
}
76-
);
77-
};
76+
const stop = useCallback(() => {
77+
abortController.current.abort();
78+
setIsLoading(false);
79+
}, [abortController]);
7880

7981
const startNewChat = useCallback(() => {
8082
setChatMessages([]);
@@ -87,5 +89,6 @@ export const useChat = () => {
8789
setInput,
8890
handleSubmit,
8991
startNewChat,
92+
stop,
9093
};
9194
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
export class Transform<T> {
2+
private open = true;
3+
private queue: T[] = [];
4+
private resolve: (() => void) | undefined;
5+
6+
async *stream(): AsyncGenerator<T> {
7+
this.open = true;
8+
9+
while (this.open) {
10+
if (this.queue.length) {
11+
yield this.queue.shift()!;
12+
continue;
13+
}
14+
15+
await new Promise<void>((resolveLocal) => {
16+
this.resolve = resolveLocal;
17+
});
18+
}
19+
}
20+
21+
push(data: T): void {
22+
this.queue.push(data);
23+
this.resolve?.();
24+
}
25+
26+
close(): void {
27+
this.open = false;
28+
this.resolve?.();
29+
}
30+
}

0 commit comments

Comments
 (0)