Why does my code get MLE on D?

Revision en1, by Meguhine, 2024-03-30 20:39:03

Could someone help me with this? My code gets MLE (Memory Limit Exceeded) on pretest 8.

Code Link

Or just see the code pasted below.

#include<bits/stdc++.h>
using namespace std;

// Local-debugging switch. While ONLINE is defined (the submission setting),
// debug(...) expands to an empty statement. Remove the define to send
// debug(...) output to stderr via fprintf.
#define ONLINE
#ifdef ONLINE
#define debug(...) ;
#else
#define debug(...) fprintf(stderr,##__VA_ARGS__)
#endif

// Shorthand type aliases used throughout the solution.
using LL  = long long;                  // wide accumulator for subset scores in bf()
using PII = std::pair<int, int>;        // (candidate score, j) entries in solve()
using PLI = std::pair<LL, int>;         // not referenced by the visible code

// Reads the next integer of type T from stdin using getchar().
// Skips any non-digit characters before the number; a '-' seen among them
// makes the result negative. Stops at the first non-digit after the number and
// consumes that character.
// Returns 0 if end of input is reached before any digit is found. The original
// version looped forever in that case, because a `char` cannot represent EOF
// distinctly.
template<typename T>
inline T READ(){
	T x=0; bool f=0;
	int c=getchar();  // int, not char, so EOF is distinguishable from data bytes
	while(c!=EOF && (c<'0' || c>'9')) f|=(c=='-'),c=getchar();
	// Parenthesize the digit value so x*10 is not first increased by the raw
	// character code: x*10+c can overflow T even when the final value fits.
	while(c>='0' && c<='9') x=x*10+(c-'0'),c=getchar();
	return f?-x:x;
}
// Typed shorthands for READ: read() parses an int, readLL() a long long.
inline int read(){
	return READ<int>();
}
inline LL readLL(){
	return READ<LL>();
}
// Mersenne Twister seeded from the system clock; not referenced by the
// visible code. The cast spells out the conversion the constructor performs.
std::mt19937 rng(static_cast<std::mt19937::result_type>(
	std::chrono::system_clock::now().time_since_epoch().count()));

void bf(int n,int m){
	vector<vector<int>>a(n);
	for(int i=0;i<n;i++){
		a[i].resize(n);
		for(int j=i;j<n;j++) a[i][j]=read();
	}
	vector<LL>ans;
	for(int i=0;i<(1<<n);i++){
		int l=-2,r=-2; LL res=0;
		for(int j=0;j<n;j++) if(i>>j&1){
			if(j!=r+1){
				if(l>=0 && r>=0) res+=a[l][r];
				l=r=j;
			}
			else r++;
		}
		if(l>=0 && r>=0) res+=a[l][r];
		ans.push_back(res);
	}
	sort(ans.rbegin(),ans.rend());
	debug("***\t");
	for(int i=0;i<m;i++) printf("%lld ",ans[i]);
	printf("\n");
}
void solve(){
	int n=read(),m=read();
	if(n*(n+1)/2<=m){//brute force
		bf(n,m);
		return;
	}
	vector<vector<int>>a(n+1);
	for(int i=1;i<=n;i++){
		a[i].resize(n+1);
		for(int j=i;j<=n;j++) a[i][j]=read();
	}
	multiset<int>TMP;
	vector<vector<int>>dp(n+3);
	dp[n+1].push_back(0);
	dp[n+2].push_back(0);
	vector<int>p(n+1);
	multiset<PII>tmp;
	for(int i=n;i>0;i--){
		TMP={dp[i+1].begin(),dp[i+1].end()};
		debug("i=%d\n",i);
		tmp.clear();
		for(int j=i;j<=n;j++){
			p[j]=dp[j+2].size()-1;
			tmp.insert({dp[j+2][p[j]]+a[i][j],j});
		}
		for(int t=0;t<m && tmp.size();t++){
			auto [val,j]=*tmp.rbegin();
			if(TMP.size()<m) TMP.insert(val);
			else if(*TMP.begin()<val){
				TMP.erase(TMP.begin());
				TMP.insert(val);
			}
			else break;
			tmp.erase(--tmp.end());
			if(p[j]!=0){
				p[j]--;
				tmp.insert({dp[j+2][p[j]]+a[i][j],j});
			}
		}
		dp[i]={TMP.begin(),TMP.end()};
	}
	debug("\t");
	for(auto it=dp[1].rbegin();it!=dp[1].rend();it++) printf("%d ",*it);
	printf("\n");
}

// Entry point: reads the number of test cases and runs solve() once per case.
int main(){
	int remaining=read();
	while(remaining--)
		solve();
	return 0;
}

/* stuff you should look for
* int overflow, array bounds
* special cases (n=1?)
* do something instead of nothing and stay organized
* WRITE STUFF DOWN
* DON'T GET STUCK ON ONE APPROACH
*/

History

 
 
 
 
Revisions
 
 
  Rev. Lang. By When Δ Comment
en1 English Meguhine 2024-03-30 20:39:03 2614 Initial revision (published)