Интерполяция многочлена от нескольких переменных, или как угадать формулу
Difference between ru8 and ru9, changed 68 character(s)
Всем привет. Я написал программу, которая использует [эту теорему](https://ru.wikipedia.org/wiki/Комбинаторная_теорема_о_нулях) для интерполяции функции как многочлена. Сначала я расскажу об устройстве программы, а потом приведу несколько способов применения. ↵

Код получился длинный, но пользоваться им просто.↵

Как использовать программу↵
------------------↵

Сначала нужно задать значения констант. N — это количество переменных в Вашем многочлене, MAX_DEG — это максимально возможная степень переменной, в которой она может входить в какой-либо из одночленов. В функции `main` нужно заполнить два массива N элементами: names содержит имена всех переменных, max_exp на i-той позиции содержит максимальный показатель степени (или оценку сверху на него), который может иметь соответствующая переменная. ↵

Обозначим `d = (max_exp[0] + 1) * (max_exp[1] + 1) * ... * (max_exp[N - 1] + 1)`. Должно выполняться, что константа MAX_PRODUCT больше, чем d. Дальше нужно написать функцию f, которая на вход принимает `array<ll, N>`, а возвращает ll или ld. В моём примере, результат работы функции &mdash; целое число, но функция возвращает ld для того, чтобы избежать переполнений ll.↵

<spoiler summary="Код">↵
~~~~~↵
#include <bits/stdc++.h>↵
using namespace std;↵
#define ll long long↵
#define ld long double↵
#define fi first↵
#define se second↵
#define pb push_back↵
#define cok cout << (ok ? "YES\n" : "NO\n");↵
#define dbg(x) cout << (#x) << ": " << (x) << endl;↵
#define dbga(x,l,r) cout << (#x) << ": "; for (int ii=l;ii<r;ii++) cout << x[ii] << " "; cout << endl;↵
// #define int long long↵
#define pi pair<int, int>↵
const int N = 7, C = 1e7, MAX_DEG = 4, MAX_PRODUCT = 1e5;↵
const ld EPS = 1e-9, EPS_CHECK = 1e-9;↵
const string SEP = "  (", END = ")\n";↵
const bool APPROXIMATION = true;↵
array <string, N> names;↵
array <int, N> max_exp, powers, current_converted, cur_exp;↵
array<vector<ll>, N> POINTS;↵
ll DIV[N][MAX_DEG + 1][MAX_DEG + 1], PW[N][MAX_DEG + 1][MAX_DEG + 1];↵
ld SUM[MAX_PRODUCT];↵
ld F_CACHE[MAX_PRODUCT];↵
ll pow(ll a, int b)↵
{↵
if (b == 0) return 1;↵
if (b == 1) return a;↵
ll s = pow(a, b / 2);↵
s *= s;↵
if (b & 1) s *= a;↵
return s;↵
}↵
ld approximate(ld k)↵
{↵
int k_ = k;↵
int k__ = k_ + abs(k) / k;↵
if (abs(k - k_) < EPS) return k_;↵
else if (abs(k - k__) < EPS) return k__;↵
else↵
{↵
int i = 1, j = 1;↵
ld ka = abs(k);↵
while (i < C && j < C)↵
{↵
ld p = ka * j;↵
if (abs(p - i) < EPS) break;↵
if (p < i) j++;↵
else i++;↵
}↵
if (i >= C || j >= C) return k;↵
if (k < 0) i = -i;↵
return (ld)i / j;↵
} ↵
}↵
void normalize(ld k)↵
{↵
    if (!APPROXIMATION)↵
    {↵
        cout << k << SEP;↵
        return;↵
    }↵
int k_ = k;↵
int k__ = k_ + abs(k) / k;↵
if (abs(k - k_) < EPS) cout << k_ << SEP;↵
else if (abs(k - k__) < EPS) cout << k__ << SEP;↵
else↵
{↵
int i = 1, j = 1;↵
ld ka = abs(k);↵
while (i < C && j < C)↵
{↵
ld p = ka * j;↵
if (abs(p - i) < EPS) break;↵
if (p < i) j++;↵
else i++;↵
}↵
if (i >= C || j >= C)↵
{↵
cout << k << SEP;↵
return;↵
}↵
if (k < 0) i = -i;↵
cout << i << "/" << j << SEP;↵
}↵
}↵
struct monom↵
{↵
array<int, N> exp;↵
ld k;↵
int deg;↵
monom(array<int, N> v, ld k_)↵
{↵
k = k_;↵
exp = v;↵
deg = 0;↵
for (int i=0;i<N;i++) deg += exp[i];↵
}↵
void display()↵
{↵
normalize(k);↵
if (deg == 0) { cout << "1" << END; return;}↵
bool go = 0;↵
for (int i=0;i<N;i++)↵
{↵
if (go && exp[i]) cout << " * ";↵
if (exp[i]) go = 1, cout << names[i] + "^" + to_string(exp[i]);↵
}↵
cout << END;↵
}↵
ld operator()(array<int, N> v)↵
{↵
ll res = 1;↵
for (int i=0;i<N;i++) res *= PW[i][v[i]][exp[i]];↵
return k * res;↵
}↵
ld getRandom(array<ll, N> v)↵
{↵
ld res = 1;↵
for (int i=0;i<N;i++) res *= pow(v[i], exp[i]);↵
return k * res;↵
}↵
};↵
bool operator<(monom a, monom b)↵
{↵
if (a.deg > b.deg) return 1;↵
if (a.deg < b.deg) return 0;↵
if (a.exp > b.exp) return 1;↵
if (a.exp < b.exp) return 0;↵
return a.k > b.k;↵
}↵
struct polynom↵
{↵
vector<monom> st;↵
void add(monom m)↵
{↵
if (abs(m.k) < EPS) return;↵
st.pb(m);↵
}↵
void print() { if(st.size() == 0) {cout << "Polynom is 0\n"; return;} sort(st.begin(), st.end()); for (monom &m: st) m.display();}↵
ld operator()(array<ll, N> v)↵
{↵
ld res = 0;↵
for (auto &m: st) res += m.getRandom(v);↵
return res;↵
}↵
};↵
ld gen(int index=0, int current_hash=0)↵
{↵
if (index == N)↵
{↵
ll div = 1;↵
for (int i=0;i<N;i++) div *= DIV[i][current_converted[i]][cur_exp[i]];↵
return (ld)(F_CACHE[current_hash] - SUM[current_hash]) / div;↵
}↵
ld res = 0;↵
for (int i=0;i<=cur_exp[index];i++)↵
{↵
current_converted[index] = i;↵
res += gen(index + 1, current_hash + i * powers[index]);↵
}↵
return res;↵
}↵
array<int, N> convert(int h)↵
{↵
array<int, N> res;↵
for (int i=0;i<N;i++) res[i] = h / powers[i], h -= res[i] * powers[i];↵
return res;↵
}↵
array<ll, N> convert_points(int h)↵
{↵
array<ll, N> res;↵
for (int i=0;i<N;i++) res[i] = POINTS[i][h / powers[i]], h %= powers[i];↵
return res;↵
}↵
polynom interpolate(ld f(array<ll, N>))↵
{↵
    int max_pow = -2e9, sum = 0, h_max = 0;↵
    set<int> remaining_points, st;↵
polynom res;↵
    for (int x: max_exp) max_pow = max(max_pow, x), sum += x, h_max = h_max * (x + 1) + x;↵

    powers[N - 1] = 1;↵
    for (int i=N-2;i>-1;i--) powers[i] = powers[i + 1] * (max_exp[i + 1] + 1);↵

    for (int i=0;i<max_exp.size();i++) for (int j=0;j<=max_exp[i];j++) POINTS[i].pb(j);↵

    for (int i=0;i<N;i++) for (int j=0;j<=max_exp[i];j++) for (int u=0;u<=max_exp[i];u++) DIV[i][j][u] = (u ? DIV[i][j][u - 1] : 1) * (u == j ? 1 : (POINTS[i][j] - POINTS[i][u]));↵

    for (int i=0;i<N;i++) for (int j=0;j<=max_exp[i];j++) for (int u=0;u<=max_pow;u++) PW[i][j][u] = u ? PW[i][j][u - 1] * POINTS[i][j] : 1;↵

    for (int i=0;i<=h_max;i++) F_CACHE[i] = f(convert_points(i)), remaining_points.insert(i);↵
    st.insert(h_max);↵

    while (st.size())↵
{↵
int v = *st.rbegin();↵
st.erase(v);↵
remaining_points.erase(v);↵
cur_exp = convert(v);↵
ld k = gen();↵
if (abs(k) > EPS)↵
{↵
monom mn = monom(cur_exp, k);↵
if (APPROXIMATION) k = approximate(k);↵
monom mn = monom(cur_exp, k);↵
res.add(mn);↵
for (int i: remaining_points) SUM[i] += mn(convert(i));↵
}↵
for (int i=0;i<N;i++) if (cur_exp[i]) st.insert(v - powers[i]);↵
}↵
return res;↵
}↵
ld f(array<ll, N> v)↵
{↵
auto [a, b, c, d, e, f, g] = v;↵
ld res = 0;↵
for (int i=0;i<a;i++)↵
for (int j=0;j<b;j++)↵
for (int u=0;u<c;u++)↵
for (int x=0;x<d;x++)↵
for (int y=0;y<e;y++)↵
for (int z=0;z<f;z++)↵
for (int k=0;k<g;k++)↵
res += 13ll * i * j * u * i * i * u - 49ll * k * k * z * z * y + 90ll * c * u * k * x * x * x;↵
return res;↵
}↵
void check(polynom p, ld(array<ll, N> f))↵
{↵
mt19937 rnd(228);↵
for (int i=0;i<10000;i++)↵
{↵
int t = clock();↵
array<ll, N> ex;↵
for (int j=0;j<N;j++) ex[j] = rnd() % 20 + 2;↵
ld F = f(ex);↵
ld P = p(ex);↵
if (abs(F - P) > max(EPS_CHECK, EPS_CHECK * abs(F)))↵
{↵
cout << "Polynom is wrong, test " << i << endl;↵
cout << F << endl << P << endl;↵
for (int x: ex) cout << x << " ";↵
cout << endl;↵
return;↵
}↵
cout << "Test " << i << " has been passed, time = " << (ld)(clock() - t) / CLOCKS_PER_SEC << "s" << endl;↵
}↵
cout << "Polynom is OK" << endl;↵
}↵
signed main()↵
{↵
    cin.tie(0); ios_base::sync_with_stdio(0);↵
    cout << setprecision(20) << fixed;↵

    names = {"a", "b", "c", "d", "e", "f", "g"};↵
    max_exp = {4, 2, 3, 4, 2, 3, 3};↵
    ↵
    polynom P = interpolate(f);↵
    P.print();↵
    //cout << "Checking polynom..." << endl;↵
    //check(P, f);↵
}↵
~~~~~↵
</spoiler>↵

#### Стрессы↵
Если раскомментировать две последние строки в main, то программа сама проверит получившийся многочлен на случайных тестах. Генерацию тестов нужно изменять под конкретную функцию f, иначе она может долго вычисляться на больших тестах.↵

#### Приближения↵
Функция из примера (и все подобные функции с N циклами) является многочленом с рациональными коэффициентами (иначе целое число на выходе мы не получим). Поэтому, в случае APPROXIMATION = true, все коэффициенты приближаются к рациональным с абсолютной погрешностью EPS при помощи функций normalize и approximate. Приближения к рациональным дробям выполняются, вероятно, не самым эффективным алгоритмом за O(числитель + знаменатель), но при небольшом количестве мономов в многочлене это недолго.↵

Функция стресс-тестирования считает результат вычисления многочлена корректным, если его абсолютная или относительная погрешность не больше, чем EPS_CHECK.↵

### Как и за сколько времени это работает↵

Мономы мы представляем в виде массива показателей степеней переменных, которые мы хэшируем. Массив PW &mdash; предпосчёт степеней, в которые возводим числа в массиве POINTS &mdash; собственно, точки, по которым мы интерполируем. Если Вы хотите задать свои точки для интерполяции, то нужно изменить массив POINTS. Если там будут дробные числа, то в начале программы нужно заменить `#define ll long long` на `#define ll long double`.↵
Массив DIV служит для быстрого вычисления знаменателей в формуле коэффициента.↵

`convert(h)` &mdash; получить индексы координат точки в массиве POINTS, соответствующей моному с хэшом h↵
`convert_points(h)` &mdash; получить координаты точки, соответствующей моному с хэшом h.↵

Далее мы предподсчитываем значения функции f во всех наших точках и записываем их в массив `F_CACHE`. Потом мы запускаем bfs по мономам, где мы при переходе от одного монома к другому уменьшаем показатель степени одной из переменных на 1.↵
Приходя в bfs'е к моному, мы находим коэффициент при нём при помощи функции `gen`. Если коэффициент ненулевой, то мы должны изменить наш многочлен для всех ещё не пройденных мономов. (Здесь мы не разделяем понятия монома и точки, так как из показателей степеней монома мы можем получить N координат точки при помощи функции `convert_points(h)`, где h &mdash; хэш монома). Это нужно для того, чтобы выполнялось одно из условий теоремы: в многочлене не должно быть мономов старше нашего. Мы для каждой точки добавляем в массив SUM значение в этом мономе, чтобы потом в функции `gen` его вычесть из результата работы функции f, для того чтобы искусственно убрать старшие мономы.↵

#### Время↵
1. Самая долгая часть предподсчета &mdash; вычисление F_CACHE &mdash; работает за O(d * O(f))↵
2. Каждый из d запусков функции gen перебирает каждую из O(d) точек за O(N)↵
3. Для каждого монома с ненулевым коэффициентом мы считаем его значение в каждой из O(d) точек за O(N)↵

Получили `O(d * O(f) + d^2 * N + d * O(res))`, где `O(res)` &mdash; время для вычисления полученного в результате многочлена.↵

### Попытка оптимизировать↵

Скорее всего, больше всего времени будет занимать рекурсия. Её можно развернуть в цикл со стеком. Это скучно, и я решил узнать, что будет, если её развернуть просто в цикл. Давайте вместо запуска рекурсии пробежимся по всем хэшам мономов, меньших нашего. Для каждого монома проверим, является ли он младше нашего (все соответствующие показатели степеней небольше). Если младше, то добавляем к текущему коэффициенту значение дроби для этой точке. Код будет какой-то такой:↵

~~~~~↵
// Вместо ld k = gen();↵
ld k = 0;↵
for (int h=0;h<=v;h++)↵
{↵
    array<int, N> cur = convert(h);↵
    bool ok = 1;↵
    for (int i=0;i<N;i++) if (cur[i] > cur_exp[i]) ok = 0;↵
    if (ok)↵
    {↵
ll div = 1;↵
        for (int i=0;i<N;i++) div *= DIV[i][cur[i]][cur_exp[i]];↵
        k += (ld)(F_CACHE[h] - SUM[h]) / div;↵
    }↵
}↵
~~~~~↵


Будет ли это быстрее? Новая реализация перебирает по 1 разу каждую пару хэшей, поэтому она работает за `O(d^2 * N)`, как и функция `gen`. Теперь оценим константу. Пар хэшей существует d * (d + 1) / 2. Константа 1 / 2. Чему равна константа количества рассмотренных точек функции gen? По сути, это количество можно посчитать при помощи функции:↵


~~~~~↵
ld f(array<ll, N> v)↵
{↵
auto [a, b, c, d, e, f, g] = v;↵
ld res = 0;↵
for (int i=0;i<a;i++)↵
for (int j=0;j<b;j++)↵
for (int u=0;u<c;u++)↵
for (int x=0;x<d;x++)↵
for (int y=0;y<e;y++)↵
for (int z=0;z<f;z++)↵
for (int k=0;k<g;k++)↵
res += (i + 1) * (j + 1) * (u + 1) * (x + 1) * (y + 1) * (z + 1) * (k + 1);↵
return res;↵
}↵
~~~~~↵

Коэффициент при a^2 * b^2 * c^2 * d^2 * e^2 * f^2 и будет нашей константой. Для нахождения этого коэффициента я воспользовался своей программой. Он оказался равен 1/128. Вообще, для N переменных он равен 1 / 2^N. То есть способ оптимизации эффективен для очень маленьких N.↵

### Заключение↵

Возможно, кому-то эта программа поможет узнать формулу для какой-то функции. Также она может раскрывать скобки, что необходимо при счёте геометрии в комплексных числах. Если Вы придумали другие способы использования, то я буду рад, если Вы ими поделитесь.↵

При N = 1 эта программа &mdash; просто интерполяция по Лагранжу, для которой существует реализация быстрее, чем за квадрат. Возможно, кто-нибудь сможет придумать ускорение и при N > 1.↵

History

 
 
 
 
Revisions
 
 
  Rev. Lang. By When Δ Comment
en2 English polosatic 2023-05-12 14:42:44 72
ru9 Russian polosatic 2023-05-12 14:41:57 68
en1 English polosatic 2023-05-10 20:14:30 13022 Initial revision for English translation
ru8 Russian polosatic 2023-05-10 19:39:53 16 Мелкая правка: 'N = 1 эта функция &mdash; п' -> 'N = 1 эта программа &mdash; п'
ru7 Russian polosatic 2023-05-10 19:36:11 16 Мелкая правка: 'му-то эта функция поможет у' -> 'му-то эта программа поможет у'
ru6 Russian polosatic 2023-05-10 19:31:20 54
ru5 Russian polosatic 2023-05-10 19:15:31 2 Мелкая правка: 'очек за O(n)\n\nПолуч' -> 'очек за O(N)\n\nПолуч'
ru4 Russian polosatic 2023-05-10 19:14:48 4
ru3 Russian polosatic 2023-05-10 18:21:17 5 Мелкая правка: 'ет `array<int, N>`, а в' -> 'ет `array<ll, N>`, а в'
ru2 Russian polosatic 2023-05-10 18:00:50 15 Мелкая правка: 'е, чем за `O(d^2)`. Возможно' -> 'е, чем за квадрат. Возможно'
ru1 Russian polosatic 2023-05-10 17:52:39 13140 Первая редакция (опубликовано)