Skip to content

Support for multiple targets. #747

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions promptsource/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,9 @@ def show_text(t, width=WIDTH, with_markdown=False):
st.markdown("###### Input template")
show_jinja(splitted_template[0].strip())
if len(splitted_template) > 1:
st.markdown("###### Target template")
show_jinja(splitted_template[1].strip())
for splitted_target in splitted_template[1:]:
st.markdown("###### Target template")
show_jinja(splitted_target.strip())
st.markdown("***")

#
Expand All @@ -437,8 +438,9 @@ def show_text(t, width=WIDTH, with_markdown=False):
st.write("Input")
show_text(prompt[0])
if len(prompt) > 1:
st.write("Target")
show_text(prompt[1])
for target in prompt[1]:
st.write("Target")
show_text(target)
st.markdown("***")
else: # mode = Sourcing
st.markdown("## Prompt Creator")
Expand Down Expand Up @@ -627,8 +629,9 @@ def show_text(t, width=WIDTH, with_markdown=False):
st.write("Input")
show_text(prompt[0], width=40)
if len(prompt) > 1:
st.write("Target")
show_text(prompt[1], width=40)
for target in prompt[1]:
st.write("Target")
show_text(target, width=40)

#
# Must sync state at end
Expand Down
10 changes: 7 additions & 3 deletions promptsource/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,14 @@ def get_fixed_answer_choices_list(self):
else:
return None

def apply(self, example, truncate=True, highlight_variables=False):
def apply(self, example, truncate=True, highlight_variables=False) -> Tuple[str, List[str]]:
"""
Creates a prompt by applying this template to an example

:param example: the dataset example to create a prompt for
:param truncate: if True, example fields will be truncated to TEXT_VAR_LENGTH chars
:param highlight_variables: highlight the added variables
:return: tuple of 2 strings, for prompt and output
:return: tuple of a string and a list of strings, for input and targets
"""
jinja = self.jinja

Expand Down Expand Up @@ -189,7 +189,11 @@ def apply(self, example, truncate=True, highlight_variables=False):

# Splits on the separator, and then replaces back any occurrences of the
# separator in the original example
return [self._unescape_pipe(part).strip() for part in rendered_example.split("|||")]
parts = [self._unescape_pipe(part).strip() for part in rendered_example.split("|||")]
if len(parts) < 2:
raise ValueError("Prompt did not produce an input and at least one target.")

return parts[0], parts[1:]

pipe_protector = "3ed2dface8203c4c9dfb1a5dc58e41e0"

Expand Down
2 changes: 1 addition & 1 deletion test/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_dataset(dataset):
# Check 2: Prompt/output separator present?
if "|||" not in template.jinja:
raise ValueError(f"Template for dataset {dataset_name}/{subset_name} "
f"with uuid {template.get_id()} has no prompt/output separator.")
f"with uuid {template.get_id()} has no input/target separator.")

# Check 3: Unique names and templates?
if template.get_name() in template_name_set:
Expand Down