上QQ阅读APP看书,第一时间看更新
2.6.2 函数的复制
在实际应用中可能会遇到多个函数之间功能相似,但是参数不同的情况,这时就可以用到Theano中的函数复制功能。例如,用同一个算法对不同的模型进行训练,不同的模型之间采用的训练参数是不同的,这时可以通过函数复制功能将一个函数复制给另一个函数。这两个函数之间具有独立的计算图结构,相互之间并不会有影响。在Theano中,通过copy函数实现函数的复制。
以累加器为例,下面的函数通过定义一个共享变量state(2.6.3节将会详细介绍共享变量)来累加变量,每一次调用函数accumulator时,state的值都会发生变化。
import theano import theano.tensor as T state = theano.shared(0) inc = T.iscalar('inc') accumulator = theano.function([inc],state,updates = [(state,state+inc)])
接下来,通过调用accumulator函数来查看输出结果。
>>>accumulator(10) array(0) >>>state.get_value() array(10)
新建另一个函数new_accumulator,它实现的功能与accumulator函数完全相同,但累加的变量不同。new_accumulator是定义在new_state上的累加函数,通过copy函数来实现这个功能,通过swap参数来交换两个共享变量。
>>>new_state = theano.shared(0) >>>new_accumulator = accumulator.copy(swap = {state:new_state})
验证结果如下:
>>>new_accumulator(100) array(0) >>>new_state.get_value() array(100) >>>state.get_value() array(10)
从上面的运行结果可以看出,new_accumulator函数没有对原来的state进行修改。如果只想在原来函数的基础上去除共享变量的更新,同样可以通过copy函数来实现这个功能,通过delete_updates参数来实现该功能。
>>>null_accumulator = accumulator.copy(delete_updates = True)
验证结果如下:
>>>null_accumulator(9000) [array(10)] >>>state.get_value() array(10)
从上述结果可以看出,调用null_accumulato函数并没有影响state变量,实际上无论何时调用该函数都会输出同样的结果。