Meguhine's blog

By Meguhine, history, 10 months ago, In English

Could someone help me with it? My code gets MLE on pretest 8.

Code Link

Or just see the code pasted below.

#include<bits/stdc++.h>
using namespace std;

#define ONLINE
#ifndef ONLINE
#define debug(...) fprintf(stderr,##__VA_ARGS__)
#else
#define debug(...) ;
#endif

using LL=long long;
using PII=pair<int,int>;
using PLI=pair<LL,int>;

template<typename T>
inline T READ(){
	T x=0; bool f=0; char c=getchar();
	while(c<'0' || c>'9') f|=(c=='-'),c=getchar();
	while(c>='0' && c<='9') x=x*10+c-'0',c=getchar();
	return f?-x:x;
}
inline int read(){return READ<int>();}
inline LL readLL(){return READ<LL>();}
mt19937 rng(chrono::system_clock::now().time_since_epoch().count());

void bf(int n,int m){
	vector<vector<int>>a(n);
	for(int i=0;i<n;i++){
		a[i].resize(n);
		for(int j=i;j<n;j++) a[i][j]=read();
	}
	vector<LL>ans;
	for(int i=0;i<(1<<n);i++){
		int l=-2,r=-2; LL res=0;
		for(int j=0;j<n;j++) if(i>>j&1){
			if(j!=r+1){
				if(l>=0 && r>=0) res+=a[l][r];
				l=r=j;
			}
			else r++;
		}
		if(l>=0 && r>=0) res+=a[l][r];
		ans.push_back(res);
	}
	sort(ans.rbegin(),ans.rend());
	debug("***\t");
	for(int i=0;i<m;i++) printf("%lld ",ans[i]);
	printf("\n");
}
void solve(){
	int n=read(),m=read();
	if(n*(n+1)/2<=m){//brute force
		bf(n,m);
		return;
	}
	vector<vector<int>>a(n+1);
	for(int i=1;i<=n;i++){
		a[i].resize(n+1);
		for(int j=i;j<=n;j++) a[i][j]=read();
	}
	multiset<int>TMP;
	vector<vector<int>>dp(n+3);
	dp[n+1].push_back(0);
	dp[n+2].push_back(0);
	vector<int>p(n+1);
	multiset<PII>tmp;
	for(int i=n;i>0;i--){
		TMP={dp[i+1].begin(),dp[i+1].end()};
		debug("i=%d\n",i);
		tmp.clear();
		for(int j=i;j<=n;j++){
			p[j]=dp[j+2].size()-1;
			tmp.insert({dp[j+2][p[j]]+a[i][j],j});
		}
		for(int t=0;t<m && tmp.size();t++){
			auto [val,j]=*tmp.rbegin();
			if(TMP.size()<m) TMP.insert(val);
			else if(*TMP.begin()<val){
				TMP.erase(TMP.begin());
				TMP.insert(val);
			}
			else break;
			tmp.erase(--tmp.end());
			if(p[j]!=0){
				p[j]--;
				tmp.insert({dp[j+2][p[j]]+a[i][j],j});
			}
		}
		dp[i]={TMP.begin(),TMP.end()};
	}
	debug("\t");
	for(auto it=dp[1].rbegin();it!=dp[1].rend();it++) printf("%d ",*it);
	printf("\n");
}

int main(){
	for(int T=read();T--;) solve();
	return 0;
}

/* stuff you should look for
* int overflow, array bounds
* special cases (n=1?)
* do smth instead of nothing and stay organized
* WRITE STUFF DOWN
* DON'T GET STUCK ON ONE APPROACH
*/

Full text and comments »

  • Vote: I like it
  • 0
  • Vote: I do not like it