|
| 1 | +# /// script |
| 2 | +# requires-python = ">=3.11" |
| 3 | +# dependencies = [ |
| 4 | +# "fastapi", |
| 5 | +# "openai", |
| 6 | +# "pydantic", |
| 7 | +# "requests", |
| 8 | +# "uvicorn", |
| 9 | +# "typer", |
| 10 | +# ] |
| 11 | +# /// |
| 12 | +import json |
| 13 | +import openai |
| 14 | +from pydantic import BaseModel |
| 15 | +import requests |
| 16 | +import sys |
| 17 | +import typer |
| 18 | +from typing import Annotated, List, Optional |
| 19 | +import urllib |
| 20 | + |
| 21 | + |
| 22 | +class OpenAPIMethod: |
| 23 | + def __init__(self, url, name, descriptor, catalog): |
| 24 | + self.url = url |
| 25 | + self.__name__ = name |
| 26 | + |
| 27 | + assert 'post' in descriptor, 'Only POST methods are supported' |
| 28 | + post_descriptor = descriptor['post'] |
| 29 | + |
| 30 | + self.__doc__ = post_descriptor.get('description', '') |
| 31 | + parameters = post_descriptor.get('parameters', []) |
| 32 | + request_body = post_descriptor.get('requestBody') |
| 33 | + |
| 34 | + self.parameters = {p['name']: p for p in parameters} |
| 35 | + assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query path parameters are supported (path: {url}, descriptor: {json.dumps(descriptor)})' |
| 36 | + |
| 37 | + self.body = None |
| 38 | + if request_body: |
| 39 | + assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {url}, descriptor: {json.dumps(descriptor)})' |
| 40 | + |
| 41 | + body_name = 'body' |
| 42 | + i = 2 |
| 43 | + while body_name in self.parameters: |
| 44 | + body_name = f'body{i}' |
| 45 | + i += 1 |
| 46 | + |
| 47 | + self.body = dict( |
| 48 | + name=body_name, |
| 49 | + required=request_body['required'], |
| 50 | + schema=request_body['content']['application/json']['schema'], |
| 51 | + ) |
| 52 | + |
| 53 | + self.parameters_schema = dict( |
| 54 | + type='object', |
| 55 | + properties={ |
| 56 | + **({ |
| 57 | + self.body['name']: self.body['schema'] |
| 58 | + } if self.body else {}), |
| 59 | + **{ |
| 60 | + name: param['schema'] |
| 61 | + for name, param in self.parameters.items() |
| 62 | + } |
| 63 | + }, |
| 64 | + components=catalog.get('components'), |
| 65 | + required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else []) |
| 66 | + ) |
| 67 | + |
| 68 | + def __call__(self, **kwargs): |
| 69 | + if self.body: |
| 70 | + body = kwargs.pop(self.body['name'], None) |
| 71 | + if self.body['required']: |
| 72 | + assert body is not None, f'Missing required body parameter: {self.body["name"]}' |
| 73 | + else: |
| 74 | + body = None |
| 75 | + |
| 76 | + query_params = {} |
| 77 | + for name, param in self.parameters.items(): |
| 78 | + value = kwargs.pop(name, None) |
| 79 | + if param['required']: |
| 80 | + assert value is not None, f'Missing required parameter: {name}' |
| 81 | + |
| 82 | + assert param['in'] == 'query', 'Only query parameters are supported' |
| 83 | + query_params[name] = value |
| 84 | + |
| 85 | + params = "&".join(f"{name}={urllib.parse.quote(value)}" for name, value in query_params.items()) |
| 86 | + url = f'{self.url}?{params}' |
| 87 | + response = requests.post(url, json=body) |
| 88 | + response.raise_for_status() |
| 89 | + response_json = response.json() |
| 90 | + |
| 91 | + return response_json |
| 92 | + |
| 93 | + |
| 94 | +def main( |
| 95 | + goal: Annotated[str, typer.Option()], |
| 96 | + api_key: Optional[str] = None, |
| 97 | + tool_endpoint: Optional[List[str]] = None, |
| 98 | + format: Annotated[Optional[str], typer.Option(help="The output format: either a Python type (e.g. 'float' or a Pydantic model defined in one of the tool files), or a JSON schema, e.g. '{\"format\": \"date\"}'")] = None, |
| 99 | + max_iterations: Optional[int] = 10, |
| 100 | + parallel_calls: Optional[bool] = False, |
| 101 | + verbose: bool = False, |
| 102 | + # endpoint: Optional[str] = None, |
| 103 | + endpoint: str = "http://localhost:8080/v1/", |
| 104 | +): |
| 105 | + |
| 106 | + openai.api_key = api_key |
| 107 | + openai.base_url = endpoint |
| 108 | + |
| 109 | + tool_map = {} |
| 110 | + tools = [] |
| 111 | + |
| 112 | + for url in (tool_endpoint or []): |
| 113 | + assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}' |
| 114 | + |
| 115 | + catalog_url = f'{url}/openapi.json' |
| 116 | + catalog_response = requests.get(catalog_url) |
| 117 | + catalog_response.raise_for_status() |
| 118 | + catalog = catalog_response.json() |
| 119 | + |
| 120 | + for path, descriptor in catalog['paths'].items(): |
| 121 | + fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog) |
| 122 | + tool_map[fn.__name__] = fn |
| 123 | + if verbose: |
| 124 | + sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(fn.parameters_schema, indent=2)}\n') |
| 125 | + tools.append(dict( |
| 126 | + type="function", |
| 127 | + function=dict( |
| 128 | + name=fn.__name__, |
| 129 | + description=fn.__doc__ or '', |
| 130 | + parameters=fn.parameters_schema, |
| 131 | + ) |
| 132 | + ) |
| 133 | + ) |
| 134 | + |
| 135 | + sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n') |
| 136 | + |
| 137 | + messages = [ |
| 138 | + dict( |
| 139 | + role="user", |
| 140 | + content=goal, |
| 141 | + ) |
| 142 | + ] |
| 143 | + |
| 144 | + i = 0 |
| 145 | + while (max_iterations is None or i < max_iterations): |
| 146 | + |
| 147 | + response = openai.chat.completions.create( |
| 148 | + model="gpt-4o", |
| 149 | + messages=messages, |
| 150 | + tools=tools, |
| 151 | + ) |
| 152 | + |
| 153 | + if verbose: |
| 154 | + sys.stderr.write(f'# RESPONSE: {response}\n') |
| 155 | + |
| 156 | + assert len(response.choices) == 1 |
| 157 | + choice = response.choices[0] |
| 158 | + |
| 159 | + content = choice.message.content |
| 160 | + if choice.finish_reason == "tool_calls": |
| 161 | + messages.append(choice.message) |
| 162 | + for tool_call in choice.message.tool_calls: |
| 163 | + if content: |
| 164 | + print(f'💭 {content}') |
| 165 | + |
| 166 | + args = json.loads(tool_call.function.arguments) |
| 167 | + pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})' |
| 168 | + sys.stdout.write(f'⚙️ {pretty_call}') |
| 169 | + sys.stdout.flush() |
| 170 | + tool_result = tool_map[tool_call.function.name](**args) |
| 171 | + sys.stdout.write(f" → {tool_result}\n") |
| 172 | + messages.append(dict( |
| 173 | + tool_call_id=tool_call.id, |
| 174 | + role="tool", |
| 175 | + name=tool_call.function.name, |
| 176 | + content=f'{tool_result}', |
| 177 | + # content=f'{pretty_call} = {tool_result}', |
| 178 | + )) |
| 179 | + else: |
| 180 | + assert content |
| 181 | + print(content) |
| 182 | + |
| 183 | + i += 1 |
| 184 | + |
| 185 | + if max_iterations is not None: |
| 186 | + raise Exception(f"Failed to get a valid response after {max_iterations} tool calls") |
| 187 | + |
| 188 | +if __name__ == '__main__': |
| 189 | + typer.run(main) |
0 commit comments