Skip to content

refactor: centralize json schema generation #212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ partially support by LLMs like GPT.

To leverage this, configure the `#[With]` attribute on the method arguments of your tool:
```php
use PhpLlm\LlmChain\Chain\JsonSchema\Attribute\With;
use PhpLlm\LlmChain\Chain\ToolBox\Attribute\AsTool;
use PhpLlm\LlmChain\Chain\ToolBox\Attribute\ToolParameter;

#[AsTool('my_tool', 'Example tool with parameters requirements.')]
final class MyTool
Expand All @@ -230,7 +230,7 @@ final class MyTool
}
```

See attribute class [With](src/Chain/ToolBox/Attribute/With.php) for all available options.
See attribute class [With](src/Chain/JsonSchema/Attribute/With.php) for all available options.

> [!NOTE]
> Please be aware, that this is only converted in a JSON Schema for the LLM to respect, but not validated by LLM Chain.
Expand Down
3 changes: 2 additions & 1 deletion composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
"php": ">=8.2",
"oskarstark/enum-helper": "^1.5",
"phpdocumentor/reflection-docblock": "^5.4",
"phpstan/phpdoc-parser": "^2.1",
"psr/cache": "^3.0",
"psr/log": "^3.0",
"symfony/clock": "^6.4 || ^7.1",
"symfony/http-client": "^6.4 || ^7.1",
"symfony/property-access": "^6.4 || ^7.1",
"symfony/property-info": "^6.4 || ^7.1",
"symfony/serializer": "^6.4 || ^7.1",
"symfony/type-info": "^6.4 || ^7.1",
"symfony/type-info": "^7.2.3",
"symfony/uid": "^6.4 || ^7.1",
"webmozart/assert": "^1.11"
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

declare(strict_types=1);

namespace PhpLlm\LlmChain\Chain\ToolBox\Attribute;
namespace PhpLlm\LlmChain\Chain\JsonSchema\Attribute;

use Webmozart\Assert\Assert;

Expand Down
49 changes: 49 additions & 0 deletions src/Chain/JsonSchema/DescriptionParser.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Chain\JsonSchema;

final readonly class DescriptionParser
{
public function getDescription(\ReflectionProperty|\ReflectionParameter $reflector): string
{
if ($reflector instanceof \ReflectionProperty) {
return $this->fromProperty($reflector);
}

return $this->fromParameter($reflector);
}

private function fromProperty(\ReflectionProperty $property): string
{
$comment = $property->getDocComment();

if (is_string($comment) && preg_match('/@var\s+[a-zA-Z\\\\]+\s+((.*)(?=\*)|.*)/', $comment, $matches)) {
return trim($matches[1]);
}

$class = $property->getDeclaringClass();
if ($class->hasMethod('__construct')) {
return $this->fromParameter(
new \ReflectionParameter([$class->getName(), '__construct'], $property->getName())
);
}

return '';
}

private function fromParameter(\ReflectionParameter $parameter): string
{
$comment = $parameter->getDeclaringFunction()->getDocComment();
if (!$comment) {
return '';
}

if (preg_match('/@param\s+\S+\s+\$'.preg_quote($parameter->getName(), '/').'\s+((.*)(?=\*)|.*)/', $comment, $matches)) {
return trim($matches[1]);
}

return '';
}
}
175 changes: 175 additions & 0 deletions src/Chain/JsonSchema/Factory.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Chain\JsonSchema;

use PhpLlm\LlmChain\Chain\JsonSchema\Attribute\With;
use Symfony\Component\TypeInfo\Type;
use Symfony\Component\TypeInfo\Type\BuiltinType;
use Symfony\Component\TypeInfo\Type\CollectionType;
use Symfony\Component\TypeInfo\Type\ObjectType;
use Symfony\Component\TypeInfo\TypeIdentifier;
use Symfony\Component\TypeInfo\TypeResolver\TypeResolver;

/**
* @phpstan-type JsonSchema array{
* type: 'object',
* properties: array<string, array{
* type: string,
* description: string,
* enum?: list<string>,
* const?: string|int|list<string>,
* pattern?: string,
* minLength?: int,
* maxLength?: int,
* minimum?: int,
* maximum?: int,
* multipleOf?: int,
* exclusiveMinimum?: int,
* exclusiveMaximum?: int,
* minItems?: int,
* maxItems?: int,
* uniqueItems?: bool,
* minContains?: int,
* maxContains?: int,
* required?: bool,
* minProperties?: int,
* maxProperties?: int,
* dependentRequired?: bool,
* }>,
* required: list<string>,
* additionalProperties: false,
* }
*/
final readonly class Factory
{
private TypeResolver $typeResolver;

public function __construct(
private DescriptionParser $descriptionParser = new DescriptionParser(),
?TypeResolver $typeResolver = null,
) {
$this->typeResolver = $typeResolver ?? TypeResolver::create();
}

/**
* @return JsonSchema|null
*/
public function buildParameters(string $className, string $methodName): ?array
{
$reflection = new \ReflectionMethod($className, $methodName);

return $this->convertTypes($reflection->getParameters());
}

/**
* @return JsonSchema|null
*/
public function buildProperties(string $className): ?array
{
$reflection = new \ReflectionClass($className);

return $this->convertTypes($reflection->getProperties());
}

/**
* @param list<\ReflectionProperty|\ReflectionParameter> $elements
*
* @return JsonSchema|null
*/
private function convertTypes(array $elements): ?array
{
if (0 === count($elements)) {
return null;
}

$result = [
'type' => 'object',
'properties' => [],
'required' => [],
'additionalProperties' => false,
];

foreach ($elements as $element) {
$name = $element->getName();
$type = $this->typeResolver->resolve($element);
$schema = $this->getTypeSchema($type);

if ($type->isNullable()) {
$schema['type'] = [$schema['type'], 'null'];
} else {
$result['required'][] = $name;
}

$description = $this->descriptionParser->getDescription($element);
if ('' !== $description) {
$schema['description'] = $description;
}

// Check for ToolParameter attributes
$attributes = $element->getAttributes(With::class);
if (count($attributes) > 0) {
$attributeState = array_filter((array) $attributes[0]->newInstance(), fn ($value) => null !== $value);
$schema = array_merge($schema, $attributeState);
}

$result['properties'][$name] = $schema;
}

return $result;
}

/**
* @return array<string, mixed>
*/
private function getTypeSchema(Type $type): array
{
switch (true) {
case $type->isIdentifiedBy(TypeIdentifier::INT):
return ['type' => 'integer'];

case $type->isIdentifiedBy(TypeIdentifier::FLOAT):
return ['type' => 'number'];

case $type->isIdentifiedBy(TypeIdentifier::BOOL):
return ['type' => 'boolean'];

case $type->isIdentifiedBy(TypeIdentifier::ARRAY):
assert($type instanceof CollectionType);
$collectionValueType = $type->getCollectionValueType();

if ($collectionValueType->isIdentifiedBy(TypeIdentifier::OBJECT)) {
assert($collectionValueType instanceof ObjectType);

return [
'type' => 'array',
'items' => $this->buildProperties($collectionValueType->getClassName()),
];
}

return [
'type' => 'array',
'items' => $this->getTypeSchema($collectionValueType),
];

case $type->isIdentifiedBy(TypeIdentifier::OBJECT):
if ($type instanceof BuiltinType) {
throw new \InvalidArgumentException('Cannot build schema from plain object type.');
}
assert($type instanceof ObjectType);
if (in_array($type->getClassName(), ['DateTime', 'DateTimeImmutable', 'DateTimeInterface'], true)) {
return ['type' => 'string', 'format' => 'date-time'];
} else {
// Recursively build the schema for an object type
return $this->buildProperties($type->getClassName());
}

// no break
case $type->isIdentifiedBy(TypeIdentifier::STRING):
default:
// Fallback to string for any unhandled types
return ['type' => 'string'];
}
}
}
6 changes: 4 additions & 2 deletions src/Chain/StructuredOutput/ResponseFormatFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

namespace PhpLlm\LlmChain\Chain\StructuredOutput;

use PhpLlm\LlmChain\Chain\JsonSchema\Factory;

use function Symfony\Component\String\u;

final readonly class ResponseFormatFactory implements ResponseFormatFactoryInterface
{
public function __construct(
private SchemaFactory $schemaFactory = new SchemaFactory(),
private Factory $schemaFactory = new Factory(),
) {
}

Expand All @@ -19,7 +21,7 @@ public function create(string $responseClass): array
'type' => 'json_schema',
'json_schema' => [
'name' => u($responseClass)->afterLast('\\')->toString(),
'schema' => $this->schemaFactory->buildSchema($responseClass),
'schema' => $this->schemaFactory->buildProperties($responseClass),
'strict' => true,
],
];
Expand Down
Loading