Meguhine's blog

By Meguhine, history, 10 months ago, In English

Could someone help me with it? My code gets MLE on pretest 8.

Code Link

Or just see the code pasted below.

#include<bits/stdc++.h>
using namespace std;

#define ONLINE
#ifndef ONLINE
#define debug(...) fprintf(stderr,##__VA_ARGS__)
#else
#define debug(...) ;
#endif

using LL=long long;
using PII=pair<int,int>;
using PLI=pair<LL,int>;

template<typename T>
inline T READ(){
	T x=0; bool f=0; char c=getchar();
	while(c<'0' || c>'9') f|=(c=='-'),c=getchar();
	while(c>='0' && c<='9') x=x*10+c-'0',c=getchar();
	return f?-x:x;
}
inline int read(){return READ<int>();}
inline LL readLL(){return READ<LL>();}
mt19937 rng(chrono::system_clock::now().time_since_epoch().count());

void bf(int n,int m){
	vector<vector<int>>a(n);
	for(int i=0;i<n;i++){
		a[i].resize(n);
		for(int j=i;j<n;j++) a[i][j]=read();
	}
	vector<LL>ans;
	for(int i=0;i<(1<<n);i++){
		int l=-2,r=-2; LL res=0;
		for(int j=0;j<n;j++) if(i>>j&1){
			if(j!=r+1){
				if(l>=0 && r>=0) res+=a[l][r];
				l=r=j;
			}
			else r++;
		}
		if(l>=0 && r>=0) res+=a[l][r];
		ans.push_back(res);
	}
	sort(ans.rbegin(),ans.rend());
	debug("***\t");
	for(int i=0;i<m;i++) printf("%lld ",ans[i]);
	printf("\n");
}
void solve(){
	int n=read(),m=read();
	if(n*(n+1)/2<=m){//brute force
		bf(n,m);
		return;
	}
	vector<vector<int>>a(n+1);
	for(int i=1;i<=n;i++){
		a[i].resize(n+1);
		for(int j=i;j<=n;j++) a[i][j]=read();
	}
	multiset<int>TMP;
	vector<vector<int>>dp(n+3);
	dp[n+1].push_back(0);
	dp[n+2].push_back(0);
	vector<int>p(n+1);
	multiset<PII>tmp;
	for(int i=n;i>0;i--){
		TMP={dp[i+1].begin(),dp[i+1].end()};
		debug("i=%d\n",i);
		tmp.clear();
		for(int j=i;j<=n;j++){
			p[j]=dp[j+2].size()-1;
			tmp.insert({dp[j+2][p[j]]+a[i][j],j});
		}
		for(int t=0;t<m && tmp.size();t++){
			auto [val,j]=*tmp.rbegin();
			if(TMP.size()<m) TMP.insert(val);
			else if(*TMP.begin()<val){
				TMP.erase(TMP.begin());
				TMP.insert(val);
			}
			else break;
			tmp.erase(--tmp.end());
			if(p[j]!=0){
				p[j]--;
				tmp.insert({dp[j+2][p[j]]+a[i][j],j});
			}
		}
		dp[i]={TMP.begin(),TMP.end()};
	}
	debug("\t");
	for(auto it=dp[1].rbegin();it!=dp[1].rend();it++) printf("%d ",*it);
	printf("\n");
}

int main(){
	for(int T=read();T--;) solve();
	return 0;
}

/* stuff you should look for
* int overflow, array bounds
* special cases (n=1?)
* do smth instead of nothing and stay organized
* WRITE STUFF DOWN
* DON'T GET STUCK ON ONE APPROACH
*/
  • Vote: I like it
  • 0
  • Vote: I do not like it

»
10 months ago, # |
Rev. 3   Vote: I like it +3 Vote: I do not like it

Disclaimer: I didn't test this, so please correct me if I understood your code incorrectly:

Consider a test case like this:

$$$n = 99$$$, $$$k = 5000$$$ (m in your code).

Since $$$99 \cdot 100 / 2 = 4950 \le 5000$$$, your code will run the function bf(). The vector ans will have size $$$2^{99}$$$, which gives you MLE.

  • »
    »
    10 months ago, # ^ |
      Vote: I like it +3 Vote: I do not like it

    Yes, I also discovered that just after I posted this lol.

    Ahh, I was such an idiot, I think I would pass if I just delete the brute force.

    Thank you anyway.