Skip to content

Commit f6a0e21

Browse files
authored
Initial support for multiple targets. (#747)
1 parent b99bfc2 commit f6a0e21

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

promptsource/app.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -413,8 +413,9 @@ def show_text(t, width=WIDTH, with_markdown=False):
413413
st.markdown("###### Input template")
414414
show_jinja(splitted_template[0].strip())
415415
if len(splitted_template) > 1:
416-
st.markdown("###### Target template")
417-
show_jinja(splitted_template[1].strip())
416+
for splitted_target in splitted_template[1:]:
417+
st.markdown("###### Target template")
418+
show_jinja(splitted_target.strip())
418419
st.markdown("***")
419420

420421
#
@@ -437,8 +438,9 @@ def show_text(t, width=WIDTH, with_markdown=False):
437438
st.write("Input")
438439
show_text(prompt[0])
439440
if len(prompt) > 1:
440-
st.write("Target")
441-
show_text(prompt[1])
441+
for target in prompt[1]:
442+
st.write("Target")
443+
show_text(target)
442444
st.markdown("***")
443445
else: # mode = Sourcing
444446
st.markdown("## Prompt Creator")
@@ -627,8 +629,9 @@ def show_text(t, width=WIDTH, with_markdown=False):
627629
st.write("Input")
628630
show_text(prompt[0], width=40)
629631
if len(prompt) > 1:
630-
st.write("Target")
631-
show_text(prompt[1], width=40)
632+
for target in prompt[1]:
633+
st.write("Target")
634+
show_text(target, width=40)
632635

633636
#
634637
# Must sync state at end

promptsource/templates.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,14 @@ def get_fixed_answer_choices_list(self):
153153
else:
154154
return None
155155

156-
def apply(self, example, truncate=True, highlight_variables=False):
156+
def apply(self, example, truncate=True, highlight_variables=False) -> Tuple[str, List[str]]:
157157
"""
158158
Creates a prompt by applying this template to an example
159159
160160
:param example: the dataset example to create a prompt for
161161
:param truncate: if True, example fields will be truncated to TEXT_VAR_LENGTH chars
162162
:param highlight_variables: highlight the added variables
163-
:return: tuple of 2 strings, for prompt and output
163+
:return: tuple of a string and a list of strings, for input and targets
164164
"""
165165
jinja = self.jinja
166166

@@ -189,7 +189,11 @@ def apply(self, example, truncate=True, highlight_variables=False):
189189

190190
# Splits on the separator, and then replaces back any occurrences of the
191191
# separator in the original example
192-
return [self._unescape_pipe(part).strip() for part in rendered_example.split("|||")]
192+
parts = [self._unescape_pipe(part).strip() for part in rendered_example.split("|||")]
193+
if len(parts) < 2:
194+
raise ValueError("Prompt did not produce an input and at least one target.")
195+
196+
return parts[0], parts[1:]
193197

194198
pipe_protector = "3ed2dface8203c4c9dfb1a5dc58e41e0"
195199

test/test_templates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_dataset(dataset):
9898
# Check 2: Prompt/output separator present?
9999
if "|||" not in template.jinja:
100100
raise ValueError(f"Template for dataset {dataset_name}/{subset_name} "
101-
f"with uuid {template.get_id()} has no prompt/output separator.")
101+
f"with uuid {template.get_id()} has no input/target separator.")
102102

103103
# Check 3: Unique names and templates?
104104
if template.get_name() in template_name_set:

0 commit comments

Comments
 (0)