def _my_model_fn(features, labels, mode):
    """Model function for a 3-class classifier built on ``MultiClassHead``.

    Meant to be passed as ``model_fn`` to ``tf.estimator.Estimator``; the
    head builds the loss, metrics, predictions and train op from the logits.

    Args:
        features: Input features supplied by the Estimator; fed directly to
            the Keras model below and forwarded to the head.
        labels: Labels supplied by the Estimator; forwarded to the head.
        mode: The mode key supplied by the Estimator (train / eval /
            predict); forwarded to the head.

    Returns:
        The ``EstimatorSpec`` produced by
        ``MultiClassHead.create_estimator_spec`` for ``mode``.
    """
    my_head = tf.estimator.MultiClassHead(n_classes=3)
    # Placeholder: substitute a real Keras model whose output is one logit per
    # class (3 here). `tf.keras.Model(...)` is not runnable as written.
    logits = tf.keras.Model(...)(features)

    return my_head.create_estimator_spec(
        features=features,
        mode=mode,
        labels=labels,
        # Use `learning_rate`, not the deprecated `lr` alias: newer Keras
        # optimizers ignore `lr` (keeping the default rate) or reject it.
        # NOTE(review): on TF >= 2.11 Estimator heads may need
        # `tf.keras.optimizers.legacy.Adagrad` -- TODO confirm for your TF.
        optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1),
        logits=logits,
    )
# Wrap the model function defined above in an Estimator.
my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
1def _my_model_fn(features, labels, mode):
2 my_head = tf.estimator.MultiClassHead(n_classes=3)
3 logits = tf.keras.Model(...)(features)