接上节,本节分析最后一个模块:6,让经过训练的模型进行预测:
当训练好一个模型并“evalute”它是有效的之后,就可以该模型对无标签样本(即包含特征但不包含标签的样本)进行一些预测。
手动提供下列三个无标签样本:
predict_x = {'SepalLength': [5.1, 5.9, 6.9],'SepalWidth': [3.3, 3.0, 3.1],'PetalLength': [1.7, 4.2, 5.4],'PetalWidth': [0.5, 1.5, 2.1],}
同evaluate方法一样,Estimator 类也提供一个 predict 方法,用于对特征数据进行预测。
predict 方法使用与evaluate方法相同的输入函数eval_input_fn,唯一不同的是,predict方法不要传递标签值(labels)给eval_input_fn函数,因为labels值是我们希望predict方法预测出的,如下图所示:
classifier.predict
predict 方法返回一个 Python 可迭代对象predictions,为每个样本生成一个预测结果字典,此字典包含几个键,如所示:
{'logits': array([ 0.01958101, -0.13283071, 0.02277012], dtype=float32), 'probabilities': array([0.34942693, 0.30003005, 0.35054305], dtype=float32), 'class_ids': array([2], dtype=int64), 'classes': array([b'2'], dtype=object)}
其中:
probabilities 键存储的是一个浮点值组成的列表,每个浮点值表示输入样本是特定鸢尾花品种的概率
class_ids 键存储的是一个 1 元素数组,用于标识可能性最大的品种
最后,遍历predictions, 打印出每个样本的预测结果
到此,鸢尾花分类程序详细解读完毕。













网友评论