问题描述
Mike is a lawyer with the gift of photographic memory. He is so good with it that he can tell you all the numbers on a sheet of paper by having a look at it without any mistake. Mike is also brilliant with subsets so he thought of giving a challange based on his skill and knowledge to Rachael. Mike knows how many subset are possible in an array of N integers. The subsets may or may not have the different sum. The challenge is to find the maximum sum produced by any subset under the condition:
The elements present in the subset should not have any digit in common.
Note: Subset {12, 36, 45} does not have any digit in common and Subset {12, 22, 35} have digits in common.Rachael find it difficult to win the challenge and is asking your help. Can youhelp her out in winning this challenge?
输入
First Line of the input consist of an integer T denoting the number of test cases. Then T test cases follow. Each test case consist of a numbe N denoting the length of the array. Second line of each test case consist of N space separated integers denoting the array elements.
Constraints:
1 <= T <= 100
1 <= N <= 100
1 <= array elements <= 100000
输出
Corresponding to each test case, print output in the new line.
示例输入
1
3
12 22 35
示例输出
57
思路
int[ ] s :输入的数字
boolean[ ] visited:0~9是否已被访问
i:考虑第i个数
       第i层递归:考察第i个数的各位数字是否都没被访问。若是则此数可能是最终结果之一,此时有两种情况:1.选中此数;2.不选此数。若否,则只有1种情况,即上述情况2。这两种情况可通过m1、m2计算:
       m1=该数+(第i+1层递归的结果,将该数的各位设置为已访问)
       m2=第i+1层递归的结果,将该数的各位设置为未访问
       最后返回m1、m2中较大者
使用动态规划改进
       递归的解法存在重复计算,通过动态规划改进之。每次递归计算所依赖的参数有两个:i、visited数组。(输入数组s没有改变,可视作全局变量)。直观地,可构建dp数组为dp[i.length][2][2][2][2][2][2][2][2][2][2],但在编程时可使用更简便的方法:将后面的10个2压缩到1个维度,即dp[i.length][2^10]。此时第2个维度上的取值范围是0~1023,将其看作二进制数,则每一位上1表示已访问,0表示未访问。
       使用动态规划优化递归过程的方式为:在进行第i次递归时,先查找dp矩阵看当前参数下的结果是否已经被计算,若已计算则直接返回之,否则按照上述过程计算,并在返回结果之前更新dp矩阵。
代码
import java.util.*;
import java.lang.*;
class Main{
    public static int func(String[] s, int n, int i, boolean[] visited, int[][] dp)
    {
        //已考虑了全部n个数
        if(i==n)
        {
            return 0;
        }
        //mask就是压缩后的维度2
        int mask=0;
        for(int y=0;y<10;y++)
        {
            if(visited[y])
            {
                mask+=Math.pow(2,y);
            }
        }
        if(dp[i][mask]!=-1)
        {
            return dp[i][mask];
        }
        String current=s[i];
        int flag=0;
        int temp[]=new int[10];
        for(int k=0;k<current.length();k++)
        {
            //char转int要减48,也可以先转为String再转为int
            int num=(int)current.charAt(k)-48;
            if(visited[num])
            {
                flag=1;
                break;
            }
            temp[num]=1;
        }
        int m1=0;
        int m2=0;
        if(flag==0)
        {
            //选中第i个数(在可以选的情况下),则修改visited并递归计算
            for(int k=0;k<10;k++)
            {
                if(temp[k]==1)
                {
                    visited[k]=true;
                }
            }
            m1=Integer.parseInt(s[i])+func(s,n,i+1,visited,dp);
            //不选择第i个数,则先还原visited
            for(int k=0;k<10;k++)
            {
                if(temp[k]==1)
                {
                    visited[k]=false;
                }
            }
        }
        //还原后再递归计算
        m2=func(s,n,i+1,visited,dp);
        //比较二者、更新dp矩阵
        return dp[i][mask]=Math.max(m1,m2);
    }
    public static void main (String[] args) {
        Scanner sc=new Scanner(System.in);
        int t=sc.nextInt();
        for(int j=0;j<t;j++)
        {
            int n=sc.nextInt();
            String s[]=new String[n];
            for(int i=0;i<n;i++)
            {
                s[i]=sc.next();
            }
            boolean visited[]=new boolean[10];
            int dp[][]=new int[n][1024];
            for(int i=0;i<n;i++)
            {
                for(int k=0;k<1024;k++)
                {
                    dp[i][k]=-1;
                }
            }
            System.out.println(func(s,n,0,visited,dp));
        }
    }
}
python代码
t = int(input())
for _ in range(t):
    n = int(input())
    S = input().strip().split()
        
    A = []
    for s in S:
        m = 0
        for ss in set(s):
            m += 1 << int(ss)
        A.append((int(s), m))
    
    A.sort(key=lambda x: x[1])
    AA = []
    i = 0
    n = len(S)
    while i < n:
        a,m = A[i]
        i += 1
        while i < n and A[i][1] == m:
            if A[i][0] > a: a = A[i][0]
            i += 1
        AA.append((a,m))
    mx = 0
    R = [(0,0)]
    while R:
        RR = []
        for r,m in R:
            if m == 0b1111111111:
                continue
    
            for a,am in AA:
                if m & am == 0:
                    rr = r + a
                    if rr > mx: mx = rr
                    RR.append((rr, m+am))
    
        RR.sort(key=lambda x: x[1])
        R = []
        i = 0
        nn = len(RR)
        while i < nn:
            r,m = RR[i]
            i += 1
            while i < nn and RR[i][1] == m:
                if RR[i][0] > r: r = RR[i][0]
                i += 1
            R.append((r,m))
    
    print(mx)











网友评论