AOJ2255 6/2(1+2)

背景
結局は9なのでしょうか
6÷2(1+2)とは (ロクワルニカッコイチタスニカッコトジとは) [単語記事] - ニコニコ大百科

問題
6/2(1+2) | Aizu Online Judge

優先度が括弧しか決まっておらず、四則演算は任意の順で計算してよいような計算方法で、与えられた式を計算した結果のパターン数を出力せよ。
数式中の数は整数で与えられる。除算は整数除算で定義される。また、計算途中で0除算するようなケースはパターンに含めてはならない。


解法

数式の構文解析は、「再帰による演算子の優先度の解決」と「for文などによる結合性の解決」が中心となる。
まず優先度の解決について、今回は括弧しか優先度が定義されていない。
よって、再帰関数は優先度を規準にで以下の二つの部分に分けることが出来る。

1. 括弧と数字のパース
2. 演算子のパース

それぞれについて、結合性の解決をする。
1. については普通にやる。単項なので結合性もなにもない。

2. について、数式1*2+3*4-5について考える。例えば以下のように(演算子, 数字)のペアを左から順に計算したとき、これらの隣りあった式をどこから計算してもよいことになる。

(none, 1), (*, 2), (+, 3), (*, 4), (-, 5)

(none, 6), (*, 2), (+, 12), (-, 5)

(none, 6), (*, 14), (-, 5)

(none, 6), (*, 9)

(none, 54)

-> 1つのパターンとして54という値が計算される。

なので、(演算子, 数字)のペアを一列に並べたら、後はBFS等で適当な順序で計算して良いことになる。

さて、四則演算による数式の計算段階について、複数の計算結果が生じる。
言い換えると、一つの数式から上のBFSで、{(none, 値0), (none, 値1), (none, 値2), ...} すなわち、 {値0, 値1, 値2, ...} という計算結果の集合が生じる。すべての集合について演算する必要があるので、再帰関数の戻り値はintではなくsetである。

よって、二項演算の処理については、集合Aと集合Bに含まれる各数字を二項演算子で計算して、それを計算結果としてsetに詰めていく必要がある。

最後にsetのsize()を答えれば、それが計算結果のパターン数となる。

以下にコードを2つ記載する。1つ目は最近解いた冗長な解法で、上の解説と同じ解法。2つ目は昔解いた解法で、区間DP的に解いている。

namespace solver {
 
string s;
typedef string::const_iterator Iter;
Iter it, end;
 
typedef pair<set<int>, char> pic;
 
void consume(char e) {
  assert(e == *it);
  it ++;
}
 
bool isconsume(char e) {
  if(e == *it) {
    consume(e);
    return true;
  }
  return false;
}
 
int number() {
  int ret = 0;
  if(!isdigit(*it)) return -1;
  while(isdigit(*it)) {
    ret *= 10;
    ret += *it - '0';
    it ++;
  }
  return ret;
}
 
set<int> calc(set<int> const& n1, set<int> const& n2, char o) {
  set<int> ret;
  for(auto&& e: n1) for(auto&& u: n2) {
    if(o == '+') ret.insert(e + u);
    if(o == '-') ret.insert(e - u);
    if(o == '*') ret.insert(e * u);
    if(o == '/' && u != 0) ret.insert(e / u);
  }
  return ret;
}
 
void push_trans(queue<vector<pic>>& q, set<vector<pic>>& used, vector<pic> const& p, int i, set<int> const& n1, set<int> const& n2, char o, char nexto) {
  set<int> r = calc(n1, n2, o);
  auto nextp = p;
  nextp.erase(nextp.begin() + i);
  nextp.erase(nextp.begin() + i);
  nextp.insert(nextp.begin() + i, {r, nexto});
  if(used.count(nextp)) return;
  used.insert(nextp);
  q.push(nextp);
}
 
set<int> expr() {
  vector<pic> v;
  while(1) {
    int num = number();
    set<int> nums;
    if(num == -1 && isconsume('(')) {
      nums = expr();
      consume(')');
    }
    else if(num == -1) {
      assert(0);
    }
    else {
      nums.insert(num);
    }
 
    if(isconsume('+')) v.push_back({nums, '+'});
    else if(isconsume('-')) v.push_back({nums, '-'});
    else if(isconsume('*')) v.push_back({nums, '*'});
    else if(isconsume('/')) v.push_back({nums, '/'});
    else {
      v.push_back({nums, '0'});
      break;
    }
  }
 
  queue<vector<pic>> q; q.push(v);
  set<vector<pic>> used;
 
  set<int> ret;
 
  for(;!q.empty(); q.pop()) {
    auto const& p = q.front();
    if(p.size() == 1) {
      for(auto&& e: p[0].first)
        ret.insert(e);
      continue;
    }
    assert(p.size());
    rep(i, p.size() - 1)
      push_trans(q, used, p, i, p[i].first, p[i+1].first, p[i].second, p[i+1].second);
  }
 
  return ret;
}
 
int solve() {
  it = s.begin(), end = s.end();
  return expr().size();
}
 
}
 
int main() {
 
  while(cin >> solver::s) {
    if(solver::s == "#") break;
    cout << solver::solve() << endl;
  }
   
  return 0;
}
typedef pair<int, int> PII;
 
const int SIZE = 100;
 
char line[SIZE], t[SIZE];
set<int> memo[SIZE][SIZE];
set<PII> used;
  
void eval(int const l, int const r) {
    
  set<int>& ret = memo[l][r];
  if(l == r) return;
    
  if(used.count(MP(l, r))) return;
  used.insert(MP(l, r));
    
  // check is digits [l, r)
  bool digits = true;
  for(int i=l; i<r; i++) { digits = digits && isdigit(line[i]); }
  strncpy(t, line+l, r-l); t[r-l] = 0;
  if(digits) { ret.insert(atoi(t)); return; }
    
  int par = 0; bool check = true;
  for(int i=l; i<r-1; i++) {
    if(line[i] == '(') { par ++; }
    if(line[i] == ')') { par --; }
    if(par == 0) { check = false; }
  }
    
  // "(expr)"
  if(check) {
    if(line[r-1] == ')') {
      eval(l+1, r-1); ret = memo[l+1][r-1]; return;
    }
  }
    
  par = 0;
  // parsing loop
  for(int i=l; i<r; i++) {
    if(line[i] == '(') { par ++; }
    if(line[i] == ')') { par --; }
    if(par != 0) continue;
      
    switch(line[i]) {
    case '+': case '-': case '*': case '/':
      eval(l, i); eval(i+1, r);
      EACH(ia, memo[l][i]) EACH(ib, memo[i+1][r]) {
        if(line[i] == '+') { ret.insert(*ia + *ib); }
        if(line[i] == '-') { ret.insert(*ia - *ib); }
        if(line[i] == '*') { ret.insert(*ia * *ib); }
        if(line[i] == '/' && *ib!=0) { ret.insert(*ia / *ib); }
      }
      break;
    }
  } // for parsing loop
}
  
int main() {
    
  while(1) {
    scanf("%s\n", line);
    for(int i=0; i<SIZE; i++)
      for(int j=0; j<SIZE; j++)
        memo[i][j].clear();
      
    used.clear();
      
    if(0==strcmp(line, "#")) break;
    eval(0, strlen(line));
    cout << memo[0][strlen(line)].size() << endl;
  }
    
  return 0;
}