 from ...utils import (
     add_start_docstrings_to_model_forward,
     logging,
-    replace_return_docstrings,
 )
 from ...utils.deprecation import deprecate_kwarg
 from ..auto import AutoModelForImageTextToText
             Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
             this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
             the complete sequence length.
-
-    Returns:
-        A `ShieldGemma2ImageClassifierOutputWithNoAttention` instance continaing the logits and probabilities
-        associated with the model predicting the `Yes` or `No` token as the response to that prompt, captured in the
-        following properties.
-
-        * `logits` (`torch.Tensor` of shape `(batch_size, 2)`):
-            The first position along dim=1 is the logits for the `Yes` token and the second position along dim=1 is
-            the logits for the `No` token.
-        * `probabilities` (`torch.Tensor` of shape `(batch_size, 2)`):
-            The first position along dim=1 is the probability of predicting the `Yes` token and the second position
-            along dim=1 is the probability of predicting the `No` token.
-
-        ShieldGemma prompts are constructed such that predicting the `Yes` token means the content *does violate* the
-        policy as described. If you are only interested in the violative condition, use
-        `violated = outputs.probabilities[:, 1]` to extract that slice from the output tensors.
-
-        When used with the `ShieldGemma2Processor`, the `batch_size` will be equal to `len(images) * len(policies)`,
-        and the order within the batch will be img1_policy1, ... img1_policyN, ... imgM_policyN.
 """


@@ -172,9 +152,6 @@ def tie_weights(self):
 
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     @add_start_docstrings_to_model_forward(SHIELDGEMMA2_INPUTS_DOCSTRING)
-    @replace_return_docstrings(
-        output_type=ShieldGemma2ImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC
-    )
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -193,9 +170,26 @@ def forward(
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **lm_kwargs,
     ) -> ShieldGemma2ImageClassifierOutputWithNoAttention:
-        """Predicts the binary probability that the image violates the speicfied policy.
+        """Predicts the binary probability that the image violates the specified policy.
 
         Returns:
+            A `ShieldGemma2ImageClassifierOutputWithNoAttention` instance containing the logits and probabilities
+            associated with the model predicting the `Yes` or `No` token as the response to that prompt, captured in the
+            following properties.
+
+            * `logits` (`torch.Tensor` of shape `(batch_size, 2)`):
+                The first position along dim=1 is the logits for the `Yes` token and the second position along dim=1 is
+                the logits for the `No` token.
+            * `probabilities` (`torch.Tensor` of shape `(batch_size, 2)`):
+                The first position along dim=1 is the probability of predicting the `Yes` token and the second position
+                along dim=1 is the probability of predicting the `No` token.
+
+            ShieldGemma prompts are constructed such that predicting the `Yes` token means the content *does violate* the
+            policy as described. If you are only interested in the violative condition, use
+            `violated = outputs.probabilities[:, 1]` to extract that slice from the output tensors.
+
+            When used with the `ShieldGemma2Processor`, the `batch_size` will be equal to `len(images) * len(policies)`,
+            and the order within the batch will be img1_policy1, ... img1_policyN, ... imgM_policyN.
         """
         outputs = self.model(
             input_ids=input_ids,
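For reviewers, a quick sanity check of the output contract documented in the relocated docstring. The sketch below is illustrative only: the checkpoint id and the reliance on the processor's built-in default policies are assumptions, not part of this diff, and the slicing recipe is the one quoted in the docstring above.

# Minimal usage sketch (assumed checkpoint id and default processor policies).
import requests
from PIL import Image
from transformers import AutoProcessor, ShieldGemma2ForImageClassification

model_id = "google/shieldgemma-2-4b-it"  # assumed checkpoint id, for illustration only
processor = AutoProcessor.from_pretrained(model_id)
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")

# Any PIL image works; the URL here is a placeholder.
image = Image.open(requests.get("https://example.com/image.png", stream=True).raw)

# With no explicit policies, the processor pairs the image with its built-in default
# policies, so batch_size == len(images) * len(policies) in the order described above.
inputs = processor(images=[image], return_tensors="pt").to(model.device)

outputs = model(**inputs)
print(outputs.probabilities)             # shape (batch_size, 2): Yes/No probabilities per image-policy pair
violated = outputs.probabilities[:, 1]   # slice suggested by the docstring for the violative condition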