1220. Count Vowels Permutation

yPhantom 2019年10月14日 29次浏览

Given an integer n, your task is to count how many strings of length n can be formed under the following rules:

  • Each character is a lower case vowel ('a', 'e', 'i', 'o', 'u')
  • Each vowel 'a' may only be followed by an 'e'.
  • Each vowel 'e' may only be followed by an 'a' or an 'i'.
  • Each vowel 'i' may not be followed by another 'i'.
  • Each vowel 'o' may only be followed by an 'i' or a 'u'.
  • Each vowel 'u' may only be followed by an 'a'.

Since the answer may be too large, return it modulo 10^9 + 7.

Example 1:

Input: n = 1
Output: 5
Explanation: All possible strings are: "a", "e", "i" , "o" and "u".

Example 2:

Input: n = 2
Output: 10
Explanation: All possible strings are: "ae", "ea", "ei", "ia", "ie", "io", "iu", "oi", "ou" and "ua".

Example 3:

Input: n = 5
Output: 68

Constraints:

  • 1 <= n <= 2 * 10^4

Solution

还是动态规划,参考这个。根据Discuss里面的有向图,实际上我们就是要求这个图中有多少条长度为n的路径。我们用dp[n][char]代表长度为n并且终点为char的长度值,那么dp[0][char]都是0,dp[1][char]都是1,dp[2][char]就是第二个字母是char的路径。

以每条路径的倒数第二个节点的视角来看,我们求dp[n+1][charA],就是等于是求所有到charY有有向路径的dp[n][charY]之和。因此代码如下:

class Solution {
    public int countVowelPermutation(int n) {
        long[][] dp = new long[n + 1][5];
        int mod = 1000000007;
        
        for (int i = 0; i < 5; i++) {
            dp[1][i] = 1;
        }
        
        for (int i = 1; i < n ; i++) {
            dp[i + 1][0] = (dp[i][1] + dp[i][2] + dp[i][4]) % mod;
            dp[i + 1][1] = (dp[i][0] + dp[i][2]) % mod;
            dp[i + 1][2] = (dp[i][1] + dp[i][3]) % mod;
            dp[i + 1][3] = dp[i][2] % mod;
            dp[i + 1][4] = (dp[i][2] + dp[i][3]) % mod;
        }
        
        long res = 0;
        for (int i = 0; i < 5; i++) {
            res = (res + dp[n][i]) % mod;
        }
        return (int)res;
    }
}

注意,Math.pow是求得浮点数,此题要用long类型。


发现一个可以将O(n)的空间复杂度降为O(1)的版本,

class Solution {
    public int countVowelPermutation(int n) {
        long[] cur = new long[5];
        long[] next = new long[5];
        int mod = 1000000007;
        
        Arrays.fill(cur, 1);
        
        for (int i = 1; i < n ; i++) {
            next[0] = (cur[1] + cur[2] + cur[4]) % mod;
            next[1] = (cur[0] + cur[2]) % mod;
            next[2] = (cur[1] + cur[3]) % mod;
            next[3] = cur[2] % mod;
            next[4] = (cur[2] + cur[3]) % mod;
            long[] tmp = cur;
            cur = next;
            next = tmp; // 这里不能直接cur = next,不然会两个指针指向同一个数组。
        }
        
        long res = 0;
        for (int i = 0; i < 5; i++) {
            res = (res + cur[i]) % mod;
        }
        return (int)res;
    }
}

第一个版本实际上就是n+1 与 n两个数组在进行运算。因此可以得到第二个版本。

按这样写的话可以理解为长度为n的字符串是由长度为1不断累加上来的。