Skip to content

Commit a922181

Browse files
authored
Make ci happy (#91)
This PR does the following: - lints all code with `black` and `isort` - deletes all unused imports to keep the code clean - runs `black` in GitHub Actions - documents these steps in the README for contributors @csunny
2 parents 642b153 + 49d6eb1 commit a922181

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+1096
-814
lines changed

.github/workflows/pylint.yml

+13-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
name: Pylint
22

3-
on: [push]
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
workflow_dispatch:
9+
10+
concurrency:
11+
group: ${{ github.event.number || github.run_id }}
12+
cancel-in-progress: true
413

514
jobs:
615
build:
@@ -17,7 +26,7 @@ jobs:
1726
- name: Install dependencies
1827
run: |
1928
python -m pip install --upgrade pip
20-
pip install pylint
21-
- name: Analysing the code with pylint
29+
pip install -U black isort
30+
- name: check the code lint
2231
run: |
23-
pylint $(git ls-files '*.py')
32+
black . --check

README.md

+4
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ The achievements of this project are thanks to the technical community, especial
215215
- [ChatGLM](https://github.com/THUDM/ChatGLM-6B) as the base model
216216
- [llama_index](https://github.com/jerryjliu/llama_index) for enhancing database-related knowledge using [in-context learning](https://arxiv.org/abs/2301.00234) based on existing knowledge bases.
217217

218+
## Contribution
219+
220+
- Please run `black .` before submitting the code.
221+
218222
<!-- GITCONTRIBUTOR_START -->
219223

220224
## Contributors

README.zh.md

+4
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ python tools/knowledge_init.py
218218
- [ChatGLM](https://github.com/THUDM/ChatGLM-6B) 基础模型
219219
- [llama-index](https://github.com/jerryjliu/llama_index) 基于现有知识库进行[In-Context Learning](https://arxiv.org/abs/2301.00234)来对其进行数据库相关知识的增强。
220220

221+
# 贡献
222+
223+
- 提交代码前请先执行 `black .`
224+
221225
<!-- GITCONTRIBUTOR_START -->
222226

223227
## 贡献者

examples/app.py

+29-19
Original file line numberDiff line numberDiff line change
@@ -2,52 +2,63 @@
22
# -*- coding:utf-8 -*-
33

44
import gradio as gr
5-
from langchain.agents import (
6-
load_tools,
7-
initialize_agent,
8-
AgentType
9-
)
10-
from pilot.model.vicuna_llm import VicunaRequestLLM, VicunaEmbeddingLLM
11-
from llama_index import LLMPredictor, LangchainEmbedding, ServiceContext
5+
from langchain.agents import AgentType, initialize_agent, load_tools
126
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
13-
from llama_index import Document, GPTSimpleVectorIndex
7+
from llama_index import (
8+
Document,
9+
GPTSimpleVectorIndex,
10+
LangchainEmbedding,
11+
LLMPredictor,
12+
ServiceContext,
13+
)
14+
15+
from pilot.model.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM
16+
1417

1518
def agent_demo():
1619
llm = VicunaRequestLLM()
1720

18-
tools = load_tools(['python_repl'], llm=llm)
19-
agent = initialize_agent(tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
20-
agent.run(
21-
"Write a SQL script that Query 'select count(1)!'"
21+
tools = load_tools(["python_repl"], llm=llm)
22+
agent = initialize_agent(
23+
tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
2224
)
25+
agent.run("Write a SQL script that Query 'select count(1)!'")
26+
2327

2428
def knowledged_qa_demo(text_list):
2529
llm_predictor = LLMPredictor(llm=VicunaRequestLLM())
2630
hfemb = VicunaEmbeddingLLM()
2731
embed_model = LangchainEmbedding(hfemb)
2832
documents = [Document(t) for t in text_list]
2933

30-
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, embed_model=embed_model)
31-
index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context)
34+
service_context = ServiceContext.from_defaults(
35+
llm_predictor=llm_predictor, embed_model=embed_model
36+
)
37+
index = GPTSimpleVectorIndex.from_documents(
38+
documents, service_context=service_context
39+
)
3240
return index
3341

3442

3543
def get_answer(q):
36-
base_knowledge = """ """
44+
base_knowledge = """ """
3745
text_list = [base_knowledge]
3846
index = knowledged_qa_demo(text_list)
3947
response = index.query(q)
4048
return response.response
4149

50+
4251
def get_similar(q):
4352
from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
53+
4454
docsearch = knownledge_tovec_st("./datasets/plan.md")
4555
docs = docsearch.similarity_search_with_score(q, k=1)
4656

4757
for doc in docs:
48-
dc, s = doc
58+
dc, s = doc
4959
print(s)
50-
yield dc.page_content
60+
yield dc.page_content
61+
5162

5263
if __name__ == "__main__":
5364
# agent_demo()
@@ -58,8 +69,7 @@ def get_similar(q):
5869
text_input = gr.TextArea()
5970
text_output = gr.TextArea()
6071
text_button = gr.Button()
61-
72+
6273
text_button.click(get_similar, inputs=text_input, outputs=text_output)
6374

6475
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
65-

examples/embdserver.py

+10-15
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,29 @@
11
#!/usr/bin/env python3
22
# -*- coding:utf-8 -*-
33

4-
import requests
54
import json
6-
import time
7-
import uuid
85
import os
96
import sys
107
from urllib.parse import urljoin
8+
119
import gradio as gr
10+
import requests
1211

1312
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
1413
sys.path.append(ROOT_PATH)
1514

1615

17-
from pilot.configs.config import Config
18-
from pilot.conversation import conv_qa_prompt_template, conv_templates
1916
from langchain.prompts import PromptTemplate
2017

18+
from pilot.configs.config import Config
19+
from pilot.conversation import conv_qa_prompt_template, conv_templates
2120

2221
llmstream_stream_path = "generate_stream"
2322

2423
CFG = Config()
2524

26-
def generate(query):
2725

26+
def generate(query):
2827
template_name = "conv_one_shot"
2928
state = conv_templates[template_name].copy()
3029

@@ -47,7 +46,7 @@ def generate(query):
4746
"prompt": prompt,
4847
"temperature": 1.0,
4948
"max_new_tokens": 1024,
50-
"stop": "###"
49+
"stop": "###",
5150
}
5251

5352
response = requests.post(
@@ -57,19 +56,18 @@ def generate(query):
5756
skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
5857

5958
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
60-
6159
if chunk:
6260
data = json.loads(chunk.decode())
6361
if data["error_code"] == 0:
64-
6562
if "vicuna" in CFG.LLM_MODEL:
6663
output = data["text"][skip_echo_len:].strip()
6764
else:
6865
output = data["text"].strip()
6966

7067
state.messages[-1][-1] = output + "▌"
71-
yield(output)
72-
68+
yield (output)
69+
70+
7371
if __name__ == "__main__":
7472
print(CFG.LLM_MODEL)
7573
with gr.Blocks() as demo:
@@ -78,10 +76,7 @@ def generate(query):
7876
text_input = gr.TextArea()
7977
text_output = gr.TextArea()
8078
text_button = gr.Button("提交")
81-
8279

8380
text_button.click(generate, inputs=text_input, outputs=text_output)
8481

85-
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")
86-
87-
82+
demo.queue(concurrency_count=3).launch(server_name="0.0.0.0")

examples/gpt_index.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
33

4-
import os
54
import logging
65
import sys
76

8-
from llama_index import SimpleDirectoryReader, GPTSimpleVectorIndex
7+
from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader
8+
99
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
1010
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
1111

1212
# read the document of data dir
1313
documents = SimpleDirectoryReader("data").load_data()
14-
# split the document to chunk, max token size=500, convert chunk to vector
14+
# split the document to chunk, max token size=500, convert chunk to vector
1515

1616
index = GPTSimpleVectorIndex(documents)
1717

1818
# save index
19-
index.save_to_disk("index.json")
19+
index.save_to_disk("index.json")

examples/gradio_test.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33

44
import gradio as gr
55

6+
67
def change_tab():
78
return gr.Tabs.update(selected=1)
89

10+
911
with gr.Blocks() as demo:
1012
with gr.Tabs() as tabs:
1113
with gr.TabItem("Train", id=0):
1214
t = gr.Textbox()
1315
with gr.TabItem("Inference", id=1):
1416
i = gr.Image()
15-
17+
1618
btn = gr.Button()
1719
btn.click(change_tab, None, tabs)
1820

19-
demo.launch()
21+
demo.launch()
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
2-
31
from pilot.source_embedding.csv_embedding import CSVEmbedding
42

53
# path = "/Users/chenketing/Downloads/share_ireserve双写数据异常2.xlsx"
@@ -8,6 +6,13 @@
86
vector_store_path = "your_path/"
97

108

11-
pdf_embedding = CSVEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "url", "vector_store_path": "vector_store_path"})
9+
pdf_embedding = CSVEmbedding(
10+
file_path=path,
11+
model_name=model_name,
12+
vector_store_config={
13+
"vector_store_name": "url",
14+
"vector_store_path": "vector_store_path",
15+
},
16+
)
1217
pdf_embedding.source_embedding()
13-
print("success")
18+
print("success")

examples/knowledge_embedding/pdf_embedding_test.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
vector_store_path = "your_path/"
77

88

9-
pdf_embedding = PDFEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "ob-pdf", "vector_store_path": vector_store_path})
9+
pdf_embedding = PDFEmbedding(
10+
file_path=path,
11+
model_name=model_name,
12+
vector_store_config={
13+
"vector_store_name": "ob-pdf",
14+
"vector_store_path": vector_store_path,
15+
},
16+
)
1017
pdf_embedding.source_embedding()
11-
print("success")
18+
print("success")

examples/knowledge_embedding/url_embedding_test.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
vector_store_path = "your_path"
66

77

8-
pdf_embedding = URLEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "url", "vector_store_path": "vector_store_path"})
8+
pdf_embedding = URLEmbedding(
9+
file_path=path,
10+
model_name=model_name,
11+
vector_store_config={
12+
"vector_store_name": "url",
13+
"vector_store_path": "vector_store_path",
14+
},
15+
)
916
pdf_embedding.source_embedding()
10-
print("success")
17+
print("success")

0 commit comments

Comments
 (0)