cercatrova's blog

By cercatrova, 17 months ago, in English

In the problem Jewel-eating Monsters, if the traveler has x coins in the evening and drops one coin into the pond at midnight, then in the morning they will have (x-1)*a coins. If they then use those coins to buy diamonds that cost c coins each, the number of coins left over is total_coins % c.
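To make the one-night rule concrete, here is a minimal C++ sketch of it. The function names after_one_night and leftover_after_diamonds are hypothetical, and the sketch assumes x >= 1 so that there is a coin to drop.

```cpp
#include <cassert>

// One night: drop one coin at midnight, and the remaining x - 1 coins
// become (x - 1) * a by morning. Assumes x >= 1.
long long after_one_night(long long x, long long a) {
    assert(x >= 1);
    return (x - 1) * a;
}

// Coins left over after buying as many diamonds (price c each) as possible.
long long leftover_after_diamonds(long long coins, long long c) {
    assert(c > 0 && coins >= 0);
    return coins % c;
}
```

For example, with x = 5 and a = 3 the traveler has (5-1)*3 = 12 coins in the morning, and with c = 5 there are 2 coins left after buying diamonds.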

Applying the same logic when the traveler repeats this for n nights:

Total_coins = (x-1)*a + ((x-1)*a-1)*a + (((x-1)*a-1)*a-1)*a + ... (n terms, one per night)
= x*a - a + x*a^2 - a^2 - a + ... + x*a^n - a^n - a^(n-1) - ... - a
= x*(a + a^2 + a^3 + ... + a^n) - (n*a + (n-1)*a^2 + (n-2)*a^3 + ... + a^n)
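A quick way to sanity-check this regrouping is to compare it with a direct simulation for small inputs. The sketch below uses hypothetical helper names and plain long long arithmetic, so it is only meant for small x, a and n. One function adds up each morning's coin count exactly as in the first line of the sum; the other evaluates the regrouped last line.

```cpp
#include <cassert>

// Direct simulation of the sum: coins is the morning amount after night k,
// and total accumulates coins_1 + coins_2 + ... + coins_n.
long long total_by_simulation(long long x, long long a, int n) {
    assert(n >= 0);
    long long coins = x, total = 0;
    for (int k = 1; k <= n; ++k) {
        coins = (coins - 1) * a;  // one night: drop a coin, the rest multiply by a
        total += coins;
    }
    return total;
}

// Regrouped form: x*(a + ... + a^n) - (n*a + (n-1)*a^2 + ... + 1*a^n).
long long total_by_regrouping(long long x, long long a, int n) {
    long long geo = 0, weighted = 0, p = 1;
    for (int k = 1; k <= n; ++k) {
        p *= a;                       // p = a^k
        geo += p;
        weighted += (n - k + 1) * p;  // coefficient of a^k is n - k + 1
    }
    return x * geo - weighted;
}
```

For x = 5, a = 3, n = 2 both functions give 12 + 33 = 45.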

For the first bracket, the geometric series formula gives (a + a^2 + a^3 + ... + a^n) = a*(a^n - 1)/(a-1), valid for a != 1.
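The geometric-series closed form can be checked the same way. This sketch (hypothetical names ipow and geometric_sum, small values only) treats a = 1 separately, because there the formula would divide by zero and the sum is simply n.

```cpp
#include <cassert>

// Integer power by repeated multiplication (small exponents only).
long long ipow(long long a, int n) {
    assert(n >= 0);
    long long r = 1;
    while (n-- > 0) r *= a;
    return r;
}

// a + a^2 + ... + a^n via a*(a^n - 1)/(a - 1); the division is exact.
long long geometric_sum(long long a, int n) {
    if (a == 1) return n;  // the closed form divides by a - 1 = 0 here
    return a * (ipow(a, n) - 1) / (a - 1);
}
```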

For the second bracket (n*a + (n-1)*a^2 + (n-2)*a^3 + ... + a^n), let
Sn = n*a + (n-1)*a^2 + (n-2)*a^3 + ... + a^n
Multiplying both sides by a:
a*Sn = n*a^2 + (n-1)*a^3 + ... + 2*a^n + a^(n+1)

Subtracting:
Sn - a*Sn = (n*a + (n-1)*a^2 + (n-2)*a^3 + ... + a^n) - (n*a^2 + (n-1)*a^3 + ... + 2*a^n + a^(n+1))
= n*a - a^2 - a^3 - ... - a^n - a^(n+1)
= (n+1)*a - a - a^2 - a^3 - ... - a^n - a^(n+1)
= (n+1)*a - (a + a^2 + a^3 + ... + a^n) - a^(n+1)
= (n+1)*a - a*(a^n - 1)/(a-1) - a^(n+1)

So Sn*(1-a) = (n+1)*a - a*(a^n - 1)/(a-1) - a^(n+1),
which gives Sn = -((n+1)*a - a*(a^n - 1)/(a-1) - a^(n+1))/(a-1)
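Since a sign slip in this closed form is easy to make, it may be worth comparing it with the term-by-term sum. A small sketch under the same assumptions (a != 1, small values, hypothetical function names s_direct and s_closed):

```cpp
#include <cassert>

// Integer power by repeated multiplication (small exponents only).
long long ipow(long long a, int n) {
    long long r = 1;
    while (n-- > 0) r *= a;
    return r;
}

// Sn = n*a + (n-1)*a^2 + ... + 1*a^n, summed term by term.
long long s_direct(long long a, int n) {
    long long s = 0, p = 1;
    for (int k = 1; k <= n; ++k) {
        p *= a;                           // p = a^k
        s += (long long)(n - k + 1) * p;
    }
    return s;
}

// Sn from the closed form: -((n+1)*a - a*(a^n - 1)/(a-1) - a^(n+1)) / (a-1).
long long s_closed(long long a, int n) {
    assert(a != 1);
    long long geo = a * (ipow(a, n) - 1) / (a - 1);
    long long num = (long long)(n + 1) * a - geo - ipow(a, n + 1);
    return -num / (a - 1);  // num is an exact multiple of (a - 1)
}
```

For a = 3 and n = 2 both return 2*3 + 9 = 15.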

Hence, Total_coins = x*a*(a^n - 1)/(a-1) - Sn
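Putting the two pieces together, here is a sketch of the closed form for Total_coins in plain integer arithmetic, again assuming a != 1 and small values (total_closed is a hypothetical name). For x = 5, a = 3, n = 2 it gives 5*12 - 15 = 45, which matches the night-by-night sum.

```cpp
#include <cassert>

// Integer power by repeated multiplication (small exponents only).
long long ipow(long long a, int n) {
    long long r = 1;
    while (n-- > 0) r *= a;
    return r;
}

// Total_coins = x * a*(a^n - 1)/(a-1) - Sn, with Sn in closed form.
long long total_closed(long long x, long long a, int n) {
    assert(a != 1);
    long long geo = a * (ipow(a, n) - 1) / (a - 1);                      // a + ... + a^n
    long long s = -((long long)(n + 1) * a - geo - ipow(a, n + 1)) / (a - 1);  // Sn
    return x * geo - s;
}
```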

And the coins remaining after buying diamonds = Total_coins % c, where c is the price of a single diamond.
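For cross-checking large inputs, one standard alternative technique (not taken from the problem statement) is to compute Total_coins % c for the summed model above without any modular division: encode the recurrence as a 3x3 matrix and raise it to the n-th power by repeated squaring. This means c need not be prime and a = 1 needs no special case. The names Mat, mul and total_mod are hypothetical, and the sketch assumes 1 <= c < 2^31 so that products of two residues fit in long long.

```cpp
#include <array>
#include <cassert>

using Mat = std::array<std::array<long long, 3>, 3>;

// Product of two 3x3 matrices with every entry reduced mod c.
Mat mul(const Mat& A, const Mat& B, long long c) {
    Mat R{};  // all zeros
    for (int i = 0; i < 3; ++i)
        for (int k = 0; k < 3; ++k)
            for (int j = 0; j < 3; ++j)
                R[i][j] = (R[i][j] + A[i][k] * B[k][j]) % c;
    return R;
}

// State (f, T, 1): f = coins this morning, T = running total of all mornings.
//   f' = a*f - a,   T' = T + f' = a*f + T - a
long long total_mod(long long x, long long a, long long n, long long c) {
    assert(c >= 1 && n >= 0);
    long long am = ((a % c) + c) % c, neg_a = (c - am) % c;  // a and -a mod c
    Mat M = {{{am, 0, neg_a}, {am, 1 % c, neg_a}, {0, 0, 1 % c}}};
    Mat P = {{{1 % c, 0, 0}, {0, 1 % c, 0}, {0, 0, 1 % c}}};  // identity
    for (; n > 0; n >>= 1) {  // binary exponentiation: P = M^n
        if (n & 1) P = mul(P, M, c);
        M = mul(M, M, c);
    }
    long long xm = ((x % c) + c) % c;
    // Start vector is (x, 0, 1); T_n is row 1 of P applied to it.
    return (P[1][0] * xm + P[1][2]) % c;
}
```

For x = 5, a = 3, n = 2, c = 7 this returns 45 % 7 = 3. The matrix form costs O(log n) multiplications, so it works for large n as well.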


I tried to implement this logic in my code. It gives the correct result (357) for the first test case, but the results for the remaining cases don't match. What mistakes am I making, in the code or in the math?
My code:

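```cpp
#include<bits/stdc++.h>
#define fast ios::sync_with_stdio(false)
using namespace std;

typedef long long ll;

// Binary exponentiation: a^b reduced mod `mod`.
ll mod_ex(ll a,ll b, ll mod){
    ll res=1;
    while(b){
        if(b%2){
            res=(res*a)%mod;
        }
        a=(a*a)%mod;
        b/=2;
    }
    return res;
}
int main()
{
    fast;
    ll x,a,n,c;
    while(true){
        cin>>x>>a>>n>>c;
        // Stop when x, a, n and c are all 0 (the sum test assumes non-negative input).
        if(x+a+n+c==0)break;

        // p = a*(a^n - 1) mod c: numerator of a + a^2 + ... + a^n
        ll p=((a%c)*(mod_ex(a,n,c)-1))%c;
        // According to Fermat's little theorem, (a/b)%c = (a * b^(c-2))%c when c is prime
        // and b is not a multiple of c.
        // q = p * (a-1)^(c-2) mod c, standing for (a + a^2 + ... + a^n) mod c
        ll q=(p*mod_ex(a-1,c-2,c))%c;

        // xq = x*(a + ... + a^n) mod c
        ll xq=(x*q)%c;
        // yq = Sn mod c, following the closed form for Sn above
        ll yq=-(((n+1)*a-q-mod_ex(a,n+1,c))%c*mod_ex(a-1,c-2,c))%c;
        // ans = (x*(a + ... + a^n) - Sn) mod c, i.e. Total_coins mod c
        ll ans=(xq-yq)%c;
        cout<<ans<<endl;
    }
    return 0;
}
```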
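Two general C++/number-theory facts are easy to check in isolation when debugging modular code like this. First, b^(m-2) is an inverse of b mod m only when m is prime and does not divide b. Second, since C++11 the % operator truncates toward zero, so it can return a negative remainder for a negative left operand. A small sketch for checking both (mod_pow, fermat_inverse_works and norm_mod are hypothetical helper names):

```cpp
#include <cassert>

// Binary exponentiation that reduces the base first, returning b^e mod m.
long long mod_pow(long long b, long long e, long long m) {
    assert(m >= 1);
    long long r = 1 % m;
    b %= m;
    if (b < 0) b += m;
    while (e > 0) {
        if (e & 1) r = r * b % m;
        b = b * b % m;
        e >>= 1;
    }
    return r;
}

// True if b^(m-2) really is an inverse of b modulo m.
bool fermat_inverse_works(long long b, long long m) {
    return (b % m) * mod_pow(b, m - 2, m) % m == 1;
}

// Maps any value into [0, m), even when v % m is negative.
long long norm_mod(long long v, long long m) {
    return ((v % m) + m) % m;
}
```

For instance, 3^5 mod 7 = 5 and 3*5 = 15 = 1 (mod 7), so the Fermat inverse works for m = 7. For m = 10, 3^8 mod 10 = 1 and 3*1 is not 1 (mod 10), so it does not.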

I am quite a beginner at competitive programming. Any constructive criticism is appreciated.
