ctx vs self
2021/2/23
需要补充的pytorch知识好多orz
ctx
可能是context
的缩写,它出现在静态方法(static method,由@staticmethod
修饰)里。这种静态方法和类外定义的方法一致(而且不会用到类里的属性),只是放在类内更合适一些(分类意义?)。使用这种静态方法时,直接用类名调用方法,例如:
LinearFunction.backward(x, y)
因为没有实例化类,所以也就没有self
可用,这里ctx
就是个调用函数需要传入的常规参数。
引一下doc背书,ctx
实际上是implicit argument:
A static method does not receive an implicit first argument. When function
decorated with @staticmethod is called, we don't pass an instance of the
class to it (as we normally do with methods). This means we can put a
function inside a class but we can't access the instance of that class (this
is useful when your method does not use the instance).
(update)这里有篇更好的blog…非常有条理:
-
ctx
是context的缩写, 翻译成"上下文; 环境" -
ctx
专门用在静态方法中 -
self
指的是实例对象; 而ctx
用在静态方法中, 调用的时候不需要实例化对象, 直接通过类名就可以调用, 所以self
在静态方法中没有意义 - 自定义的
forward()
方法和backward()
方法的第一个参数必须是ctx
;ctx
可以保存forward()
中的变量,以便在backward()
中继续使用, 下一条是具体的示例 -
ctx.save_for_backward(a, b)
能够保存forward()
静态方法中的张量, 从而可以在backward()
静态方法中调用, 具体地, 下面地代码通过a, b = ctx.saved_tensors
重新得到a和b -
ctx.needs_input_grad
是一个元组, 元素是True
或者False
, 表示forward()
中对应的输入是否需要求导, 比如ctx.needs_input_grad[0]
指的是下面forward()
代码中indices
是否需要求导
代码也嫖过来…
class SpecialSpmmFunction(torch.autograd.Function):
"""
Special function for only sparse region backpropataion layer.
"""
# 自定义前向传播过程
@staticmethod
def forward(ctx, indices, values, shape, b):
assert indices.requires_grad == False
a = torch.sparse_coo_tensor(indices, values, shape)
ctx.save_for_backward(a, b)
ctx.N = shape[0]
return torch.matmul(a, b)
# 自定义反向传播过程
@staticmethod
def backward(ctx, grad_output):
a, b = ctx.saved_tensors
grad_values = grad_b = None
if ctx.needs_input_grad[1]:
grad_a_dense = grad_output.matmul(b.t())
edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
grad_values = grad_a_dense.view(-1)[edge_idx]
if ctx.needs_input_grad[3]:
grad_b = a.t().matmul(grad_output)
return None, grad_values, None, grad_b
这个知乎问题进一步讨论了ctx.save_for_backward
和直接ctx.input = input
的区别,但是并没有很懂。