Compute the hidden state for a single timestep in a standard RNN.
The update rule for a vanilla RNN is:
$$a_t = \tanh\left(W_{aa} \cdot a_{t-1} + W_{ax} \cdot x_t + b_a\right)$$
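Where:
- $a_t$: the new hidden state, shape (hidden_dim,)
- $a_{t-1}$: the previous hidden state (prev_hidden), shape (hidden_dim,)
- $x_t$: the input vector at timestep $t$ (input_vec), shape (input_dim,)
- $W_{aa}$: the hidden-to-hidden weight matrix (W_aa), shape (hidden_dim, hidden_dim)
- $W_{ax}$: the input-to-hidden weight matrix (W_ax), shape (hidden_dim, input_dim)
- $b_a$: the bias vector (b_a), shape (hidden_dim,)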
Write a function rnn_step(prev_hidden, input_vec, W_aa, W_ax, b_a) that computes and returns the new hidden state using np.tanh.
Return a numpy array of shape (hidden_dim,) representing the new hidden state.
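A minimal sketch of one possible solution, assuming the weight matrices multiply column vectors from the left (W_aa @ prev_hidden, W_ax @ input_vec), as in the example's reasoning line. Inputs are converted with np.asarray so plain Python lists work as well as arrays.

```python
import numpy as np

def rnn_step(prev_hidden, input_vec, W_aa, W_ax, b_a):
    """One vanilla-RNN step: a_t = tanh(W_aa @ a_{t-1} + W_ax @ x_t + b_a)."""
    # Accept lists or arrays; cast everything to float arrays.
    a_prev = np.asarray(prev_hidden, dtype=float)   # shape (hidden_dim,)
    x_t = np.asarray(input_vec, dtype=float)        # shape (input_dim,)
    W_aa = np.asarray(W_aa, dtype=float)            # shape (hidden_dim, hidden_dim)
    W_ax = np.asarray(W_ax, dtype=float)            # shape (hidden_dim, input_dim)
    b_a = np.asarray(b_a, dtype=float)              # shape (hidden_dim,)

    # Pre-activation: recurrent term + input term + bias.
    z = W_aa @ a_prev + W_ax @ x_t + b_a
    # Element-wise tanh squashes each component into (-1, 1).
    return np.tanh(z)


# Usage with the example values from this problem.
h = rnn_step([0, 0], [1, 2],
             [[0.1, 0.2], [0.3, 0.4]],
             [[0.5, 0.6], [0.7, 0.8]],
             [0, 0])
print(np.round(h, 4))  # [0.9354 0.9801]
```

Because prev_hidden is all zeros in the example, the recurrent term contributes nothing there; with a nonzero previous state, W_aa @ prev_hidden carries information forward from earlier timesteps.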
Example:
Input: prev_hidden=[0, 0], input_vec=[1, 2], W_aa=[[0.1, 0.2], [0.3, 0.4]], W_ax=[[0.5, 0.6], [0.7, 0.8]], b_a=[0, 0]
Output: [0.9354, 0.9801] (approximately)
Reasoning: tanh(W_aa @ prev_hidden + W_ax @ input_vec + b_a) = tanh([1.7, 2.3])
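Expanding the example term by term with the given values makes the intermediate vector [1.7, 2.3] explicit:

```latex
W_{aa} a_{t-1} = \begin{bmatrix} 0.1 & 0.2 \\ 0.3 & 0.4 \end{bmatrix}
                 \begin{bmatrix} 0 \\ 0 \end{bmatrix}
               = \begin{bmatrix} 0 \\ 0 \end{bmatrix}

W_{ax} x_t = \begin{bmatrix} 0.5 & 0.6 \\ 0.7 & 0.8 \end{bmatrix}
             \begin{bmatrix} 1 \\ 2 \end{bmatrix}
           = \begin{bmatrix} 0.5 + 1.2 \\ 0.7 + 1.6 \end{bmatrix}
           = \begin{bmatrix} 1.7 \\ 2.3 \end{bmatrix}

a_t = \tanh\left( \begin{bmatrix} 0 \\ 0 \end{bmatrix}
                + \begin{bmatrix} 1.7 \\ 2.3 \end{bmatrix}
                + \begin{bmatrix} 0 \\ 0 \end{bmatrix} \right)
    \approx \begin{bmatrix} 0.9354 \\ 0.9801 \end{bmatrix}
```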