and_broadcast_mut for Zip #478

Closed
@millardjn

Description

Currently, Zip has no mutable equivalent of the and_broadcast method, presumably so that it can safely be split() and parallelized. However, this makes some calculations difficult. The examples below are for backprop through a broadcastable element-wise multiplication, but the same pattern comes up often.

Currently I'm looking at using this:

unsafe{
	let input2_grad = data.get_mut(&self.input2_id.gradient_id())?;

	// do not split/parallelise this Zip!
	Zip::from(&input1)
		.and(&output_grad)
		.and_broadcast(&input2_grad)
		.apply(|input1, out_grad, input2_grad| {
			let input2_grad = input2_grad as *const f32 as *mut f32;
			*input2_grad += input1 * out_grad;
		});
}

which works correctly, but also has UB written all over it depending on how Zip is implemented internally (I realize Zip owes me nothing here).

The safe alternative I've found is:

let iter = input1.exact_chunks(input2.shape()).into_iter()
	.zip(output_grad.exact_chunks(input2.shape()));

for (input1_chunk, out_grad_chunk) in iter {
	Zip::from(&mut input2_grad)
		.and(&input1_chunk)
		.and(&out_grad_chunk)
		.apply(|input2_grad, input1, out_grad| {
			*input2_grad += input1 * out_grad;
		});
}

which is catastrophically (over 30x) slower in some cases.

So my questions are:

  • Is there a safe, high-performance workaround that I've missed?
  • Would it be possible to add an and_broadcast_mut() method that sets a flag (or something similar) to make split() panic?
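
For reference, the accumulation that both snippets above perform can be written in plain Rust without Zip. This is only a minimal sketch under stated assumptions: all three buffers are contiguous, row-major f32 slices, and input2 is broadcast along the leading axis only, so input1 and output_grad are a whole number of input2-sized rows. The function name accumulate_broadcast_grad is hypothetical and is not part of ndarray. It has not been benchmarked against either snippet, so no performance claim is made.

```rust
/// Accumulates the gradient of `input2` for `output = input1 * input2`,
/// where `input2` (length `n`) is broadcast along the leading axis of
/// `input1` and `output_grad` (each of length `k * n`).
///
/// Hypothetical helper for illustration only; not an ndarray API.
/// Assumes contiguous, row-major data.
fn accumulate_broadcast_grad(input1: &[f32], output_grad: &[f32], input2_grad: &mut [f32]) {
    let n = input2_grad.len();
    assert!(n > 0, "input2_grad must not be empty");
    assert_eq!(input1.len(), output_grad.len(), "input1 and output_grad must match");
    assert_eq!(input1.len() % n, 0, "input1 must be a whole number of rows");

    // Walk the two large buffers one broadcast row at a time. Each row
    // lines up element-for-element with the small gradient buffer, so the
    // only mutable borrow is the exclusive one on `input2_grad`.
    for (in_row, grad_row) in input1.chunks_exact(n).zip(output_grad.chunks_exact(n)) {
        for ((g, &x), &dy) in input2_grad.iter_mut().zip(in_row).zip(grad_row) {
            *g += x * dy;
        }
    }
}
```

With input1 = [1, 2, 3, 4, 5, 6], output_grad = [1, 1, 1, 2, 2, 2] and a zeroed input2_grad of length 3, the two rows contribute [1, 2, 3] and [8, 10, 12]. Broadcasting along other axes, or non-contiguous layouts, would need different indexing and is not covered here.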
