Solving linear recurrence relations using Matrix Exponentiation

Many of us are familiar with linear recurrences; they are one of the frequent topics that come up in computer science and programming contests. One of the classic problems is finding the nth term of the Fibonacci sequence. Using the recurrence relation and dynamic programming, we can calculate the nth term in O(n) time. But often n is so large (of the order of 10^10 or more) that we need to calculate the nth term in O(log n) time. This is where Matrix Exponentiation comes in handy: any linear recurrence with constant coefficients, in which each term depends on the previous K terms, can be rewritten as a multiplication by a K x K matrix.

Linear recurrences
A linear recurrence is a sequence in which each term (apart from a few initial ones) is a linear combination of a fixed number of previous terms, with constant coefficients.
A well-known example is the Fibonacci sequence:
f(i)=f(i-1)+f(i-2)
Another example is the Tribonacci sequence:
f(i)=f(i-1)+f(i-2)+f(i-3)

The Fibonacci sequence is already very popular by itself, so we will use the Tribonacci sequence to explain the concepts.

Solving the problem (Tribonacci sequence)
First, get the initial terms that are not given by the recurrence (the initial values, as some call them).
For the Tribonacci sequence:
f(0)=0
f(1)=0
f(2)=1
and for all i > 2,
we have the recurrence f(i)=f(i-1)+f(i-2)+f(i-3)
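For comparison, the O(n) approach mentioned at the start simply walks the recurrence forward from the initial values. The sketch below assumes the same modulus 10^9+7 used by the program later in this post; the function name tribonacciLinear is hypothetical.

```cpp
#include <cassert>

// Computes f(n) for the Tribonacci sequence f(0)=0, f(1)=0, f(2)=1,
// f(i)=f(i-1)+f(i-2)+f(i-3), modulo 1e9+7, by iterating the recurrence.
// O(n) time, O(1) extra space.
long long tribonacciLinear(long long n) {
    const long long MOD = 1000000007LL;
    assert(n >= 0);
    long long a = 0, b = 0, c = 1;  // a=f(i), b=f(i+1), c=f(i+2), starting at i=0
    for (long long i = 0; i < n; ++i) {
        long long next = (a + b + c) % MOD;  // f(i+3)
        a = b;
        b = c;
        c = next;
    }
    return a;  // after n steps, a holds f(n)
}
```

This is fine for moderate n, but its running time grows linearly with n, which is exactly what the matrix method avoids.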
Let K be the number of previous terms that the ith term depends on; here K = 3.
Define the column vector F(i) as a K x 1 matrix whose first row is f(i), second row is f(i-1), and so on, until the K-th row is f(i-K+1). The newest term goes in the first row, which matches the layout of the matrix M below. The initial values of f fill the column vector F(K-1), which holds f(K-1) down to f(0):

            |  f(K-1)  |
            |  f(K-2)  |
    F(K-1)= | ......   |
            | ......   |
            |  f(0)    |

For the Tribonacci sequence

            |   1    |
      F(2)= |   0    |
            |   0    |

Now what we are interested in is the relation F(i+1) = M F(i) for some constant matrix M.
So our aim is to find this constant matrix M for a given recurrence relation.

For a recurrence relation where the next term depends on the last K terms, F(i+1) and F(i) are matrices of size K x 1 and M is a matrix of size K x K.

| f(n+1)   |       | f(n)     |
|  f(n)    |       | f(n-1)   |
| f(n-1)   | = M x | f(n-2)   |
| ......   |       | ......   |
| f(n-K+2) |       | f(n-K+1) |
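With the newest term in the first row, M has a fixed shape for any recurrence f(i) = c1*f(i-1) + c2*f(i-2) + ... + cK*f(i-K): the first row holds the coefficients c1..cK, and every other row just copies the entry above it, shifting the window down by one. The Tribonacci case below works this out by hand. A minimal sketch of the general construction, where the Matrix alias and the function name buildRecurrenceMatrix are hypothetical:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Builds the K x K matrix M for f(i) = c[0]*f(i-1) + c[1]*f(i-2) + ... + c[K-1]*f(i-K),
// assuming the state vector is ordered newest term first: [f(n), f(n-1), ..., f(n-K+1)].
Matrix buildRecurrenceMatrix(const std::vector<long long>& c) {
    const std::size_t K = c.size();
    assert(K > 0);
    Matrix M(K, std::vector<long long>(K, 0));
    for (std::size_t j = 0; j < K; ++j) {
        M[0][j] = c[j];       // first row: the recurrence's coefficients
    }
    for (std::size_t r = 1; r < K; ++r) {
        M[r][r - 1] = 1;      // row r copies entry r-1 of the old state
    }
    return M;
}
```

For example, the coefficients {1, 1} give the familiar Fibonacci matrix {{1, 1}, {1, 0}}.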

Let's check this for the Tribonacci sequence.
The relation is f(i)=f(i-1)+f(i-2)+f(i-3)

| f(n+1)   |       | f(n)   |
|  f(n)    | = M x | f(n-1) |
| f(n-1)   |       | f(n-2) |

Now we know that M is a 3 x 3 matrix, so

| f(n+1) | = | a b c | x | f(n)   |
| f(n)   |   | d e f |   | f(n-1) |
| f(n-1) |   | g h i |   | f(n-2) |

We need to find the values of these unknown constants in matrix M.
The series is 0, 0, 1, 1, 2, 4, 7, 13, 24, 44, 81, 149, 274, 504, ...
and we have 3 equations:
f(n+1) = a*f(n)+b*f(n-1)+c*f(n-2)   (eq 1)
f(n)   = d*f(n)+e*f(n-1)+f*f(n-2)   (eq 2)
f(n-1) = g*f(n)+h*f(n-1)+i*f(n-2)   (eq 3)

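Suppose we take n = 2 in eq 1:
f(3) = a*f(2)+b*f(1)+c*f(0)
Writing out the first few terms of the series by hand, f(0)=0, f(1)=0, f(2)=1, f(3)=1, this becomes 1 = a*1, so a = 1.
Taking n = 3 gives f(4) = a*f(3)+b*f(2)+c*f(1), i.e. 2 = 1 + b, so b = 1. Taking n = 4 gives f(5) = a*f(4)+b*f(3)+c*f(2), i.e. 4 = 2 + 1 + c, so c = 1. This is just the recurrence itself: the first row of M holds its coefficients.
Equations 2 and 3 only shift the window down by one position. Since f(n) = d*f(n)+e*f(n-1)+f*f(n-2) must hold for every n, we get d = 1, e = 0 and the entry f = 0; likewise g = 0, h = 1, i = 0.
So the matrix M is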

| f(n+1) | = | 1 1 1 | x | f(n)   |
| f(n)   |   | 1 0 0 |   | f(n-1) |
| f(n-1) |   | 0 1 0 |   | f(n-2) |
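As a quick numerical check of this M, the sketch below starts from F(2) = [1, 0, 0] and applies M repeatedly; the first row after each step should reproduce the series listed above. The names stepTribonacci and tribonacciPrefix are hypothetical, and no modulus is applied, so it is only meant for small n.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Multiplies the Tribonacci matrix M by F(n) = [f(n), f(n-1), f(n-2)],
// producing F(n+1) = [f(n+1), f(n), f(n-1)].
std::array<long long, 3> stepTribonacci(const std::array<long long, 3>& F) {
    const long long M[3][3] = {{1, 1, 1}, {1, 0, 0}, {0, 1, 0}};
    std::array<long long, 3> next{};  // zero-initialized
    for (int r = 0; r < 3; ++r) {
        for (int c = 0; c < 3; ++c) {
            next[r] += M[r][c] * F[c];
        }
    }
    return next;
}

// Returns f(0), f(1), ..., f(last) by applying M to F(2) over and over.
std::vector<long long> tribonacciPrefix(int last) {
    assert(last >= 2);
    std::vector<long long> terms = {0, 0, 1};  // f(0), f(1), f(2)
    std::array<long long, 3> F = {1, 0, 0};    // F(2) = [f(2), f(1), f(0)]
    for (int n = 3; n <= last; ++n) {
        F = stepTribonacci(F);                 // F now holds F(n)
        terms.push_back(F[0]);                 // the first row of F(n) is f(n)
    }
    return terms;
}
```

Each call to stepTribonacci is one application of F(i+1) = M F(i); the rest of the post replaces these n steps with a fast power of M.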

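Now that we have our matrix M, finding the nth term is easy. Starting from the initial vector F(2) = [f(2), f(1), f(0)]:
F(3) = M F(2)
F(4) = M F(3) = M^2 F(2)
and so on, until
F(n) = M^(n-2) F(2)

Our answer is the first row of F(n), which is f(n). (For a general recurrence of order K, F(n) = M^(n-K+1) F(K-1).)
So all we need now is the matrix M^(n-2) to find the nth term.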
The most popular method is exponentiation by squaring, which needs only O(log n) matrix multiplications, using this recurrence:
B^e = B if e = 1
B^e = B * B^(e-1) if e is odd
B^e = X^2, where X = B^(e/2), if e is even
Each multiplication of two K x K matrices takes O(K^3) time, hence the overall complexity is O(K^3 log n).
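The three-case recurrence above translates almost line by line into a recursive function. The sketch below is a generalization to any order K, not the author's code: it assumes non-negative coefficients and initial values, reduces everything modulo 10^9+7, and the names multiply, power and nthTerm are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;
const long long MOD = 1000000007LL;

// Returns A * B modulo MOD; entries stay below MOD, so each product fits in 64 bits.
Matrix multiply(const Matrix& A, const Matrix& B) {
    const std::size_t K = A.size();
    Matrix C(K, std::vector<long long>(K, 0));
    for (std::size_t i = 0; i < K; ++i)
        for (std::size_t k = 0; k < K; ++k)
            for (std::size_t j = 0; j < K; ++j)
                C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % MOD;
    return C;
}

// B^e for e >= 1: B^1 = B, B^e = B * B^(e-1) if e is odd, B^e = (B^(e/2))^2 if e is even.
Matrix power(const Matrix& B, long long e) {
    assert(e >= 1);
    if (e == 1) return B;
    if (e % 2 == 1) return multiply(B, power(B, e - 1));
    Matrix X = power(B, e / 2);
    return multiply(X, X);
}

// f(n) for f(i) = c[0]*f(i-1) + ... + c[K-1]*f(i-K), given init = {f(0), ..., f(K-1)}.
long long nthTerm(const std::vector<long long>& c, const std::vector<long long>& init, long long n) {
    const std::size_t K = c.size();
    assert(K > 0 && init.size() == K && n >= 0);
    if (n < static_cast<long long>(K)) return init[n] % MOD;
    Matrix M(K, std::vector<long long>(K, 0));
    for (std::size_t j = 0; j < K; ++j) M[0][j] = c[j] % MOD;  // coefficient row
    for (std::size_t r = 1; r < K; ++r) M[r][r - 1] = 1;       // shift rows
    Matrix P = power(M, n - static_cast<long long>(K) + 1);    // M^(n-K+1)
    // F(K-1) = [f(K-1), ..., f(0)]; f(n) is the first row of P * F(K-1).
    long long result = 0;
    for (std::size_t j = 0; j < K; ++j)
        result = (result + P[0][j] * (init[K - 1 - j] % MOD)) % MOD;
    return result;
}
```

For instance, nthTerm({1, 1, 1}, {0, 0, 1}, n) gives the nth Tribonacci number and nthTerm({1, 1}, {0, 1}, n) the nth Fibonacci number. The recursion depth is only about 2 log2(n), so it is safe even for n around 10^18.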

Here's the source code for finding the nth term of the Tribonacci sequence.
For convenience, it computes the nth term modulo 10^9+7.

/* Amrendra Kumar */
#include<cstdio>
#include<ctime>

#define MOD 1000000007LL
typedef long long LL;

//multiplies two 3*3 matrices in place: F = F*M (mod MOD)
//(a template could be used to generalize this to K*K matrices)
void multiplyMM(LL F[3][3], LL M[3][3]){
    LL l[3][3];
    for(int i=0;i<3;i++){
        for(int j=0;j<3;j++){
            l[i][j]=0;
            for(int k=0;k<3;k++){
                l[i][j]=(l[i][j]+F[i][k]*M[k][j])%MOD;
            }
        }
    }
    //copy back only after every entry is computed, so F and M may be the same matrix
    for(int i=0;i<3;i++){
        for(int j=0;j<3;j++){
            F[i][j]=l[i][j];
        }
    }
}

/*
Don't worry about the modulo operations; I used the same code on CodeChef.
Every entry stays below MOD, so each product fits in a long long.
*/

//multiplies the 3*3 matrix m by the 3*1 column vector a, in place: a = m*a (mod MOD)
void multiplyMF(LL m[3][3],LL a[3]){
    LL x[3];
    for(int i=0;i<3;i++){
        x[i]=0;
        for(int j=0;j<3;j++){
            x[i]=(x[i]+m[i][j]*a[j])%MOD;
        }
    }
    for(int i=0;i<3;i++){
        a[i]=x[i];
    }
}

//exponentiation by squaring: replaces f with f^n, for any n >= 0
void powerM(LL f[3][3],LL n){
    LL base[3][3];
    for(int i=0;i<3;i++){
        for(int j=0;j<3;j++){
            base[i][j]=f[i][j];
            f[i][j]=(i==j);  //start from the identity matrix, i.e. f^0
        }
    }
    while(n>0){
        if(n%2!=0){
            multiplyMM(f,base);  //odd exponent: multiply in one factor of base
        }
        multiplyMM(base,base);   //base, base^2, base^4, ...
        n/=2;
    }
}

//main function for calculating the n-th tribonacci number (n >= 0)
LL tribonacci(LL n){
    //initial values f(0), f(1), f(2)
    LL init[3]={0,0,1};

    if(n<3){//handling initial values
        return init[n];
    }

    //column vector F(2) = [f(2), f(1), f(0)], newest term first
    LL F[3]={init[2],init[1],init[0]};

    //taking the matrix M
    LL M[3][3]={{1,1,1},{1,0,0},{0,1,0}};
    //calculating M^(n-2)
    powerM(M,n-2);

    //F(n) = M^(n-2) * F(2); its first row is f(n)
    multiplyMF(M,F);
    return F[0];
}

LL N;
void doThis(){
    scanf("%lld",&N);//scanf and printf are faster than cin and cout
    printf("%lld-->%lld\n",N,tribonacci(N));
}

int main(){
#ifdef amy
freopen("C:\\A\\in.txt","r",stdin);freopen("C:\\A\\out.txt","w",stdout);
#endif
int t;
scanf("%d",&t);//for number of test cases
while(t--){doThis();}
#ifdef amy
fprintf(stdout,"\nTIME: %.3lf sec\n",(double)clock()/(CLOCKS_PER_SEC));
#endif
return 0;
}

/*
INPUT:

21
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

OUTPUT:

0-->0
1-->0
2-->1
3-->1
4-->2
5-->4
6-->7
7-->13
8-->24
9-->44
10-->81
11-->149
12-->274
13-->504
14-->927
15-->1705
16-->3136
17-->5768
18-->10609
19-->19513
20-->35890

TIME: 0.002 sec

*/

Variations of the linear recurrence relation can be handled the same way; we just have to construct the constant matrix M carefully.
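As one illustrative variation (our own example, not from the original post), a recurrence with a constant term, f(i) = f(i-1)+f(i-2)+f(i-3)+c, fits the same scheme if the state vector is extended with a constant 1: the state becomes [f(n), f(n-1), f(n-2), 1] and M becomes 4 x 4, with c in the last column of the first row and a 1 in the bottom-right corner that keeps the constant alive. A minimal sketch, assuming the Tribonacci initial values and a non-negative c; the names Mat4, mul4 and tribonacciPlusConstant are hypothetical:

```cpp
#include <array>
#include <cassert>

using Mat4 = std::array<std::array<long long, 4>, 4>;
const long long MOD4 = 1000000007LL;

// Returns A * B modulo MOD4.
Mat4 mul4(const Mat4& A, const Mat4& B) {
    Mat4 C{};  // all zeros
    for (int i = 0; i < 4; ++i)
        for (int k = 0; k < 4; ++k)
            for (int j = 0; j < 4; ++j)
                C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % MOD4;
    return C;
}

// f(n) for f(i) = f(i-1) + f(i-2) + f(i-3) + c with f(0)=0, f(1)=0, f(2)=1,
// using the augmented state [f(n), f(n-1), f(n-2), 1].
long long tribonacciPlusConstant(long long n, long long c) {
    assert(n >= 0 && c >= 0);
    const long long init[3] = {0, 0, 1};
    if (n < 3) return init[n];
    c %= MOD4;
    Mat4 M = {{{1, 1, 1, c}, {1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 1}}};
    Mat4 P{};
    for (int i = 0; i < 4; ++i) P[i][i] = 1;        // P starts as the identity
    for (long long e = n - 2; e > 0; e >>= 1) {     // P = M^(n-2) by squaring
        if (e & 1) P = mul4(P, M);
        M = mul4(M, M);
    }
    // F(2) = [f(2), f(1), f(0), 1] = [1, 0, 0, 1]; f(n) is the first row of P * F(2).
    return (P[0][0] + P[0][3]) % MOD4;
}
```

With c = 0 this reduces to the plain Tribonacci sequence; with c = 1 it gives 0, 0, 1, 2, 4, 8, 15, 28, ...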
