核心代码如下：

  # Build the shared feature extractor: ImageNet-pretrained Xception with its
  # 1000-way classification head ('predictions') removed.
  #
  # Truncation uses the documented intermediate-output pattern: build a new
  # Model that ends at the 'avg_pool' layer, which sits directly before
  # 'predictions' in Keras' Xception. Do NOT truncate by mutating Keras
  # internals (layers.pop(), outputs, outbound_nodes, output_layers). In
  # tf.keras `Model.layers` returns a new list on every access, so pop() is a
  # no-op and the model would silently keep its ImageNet softmax output.
  base_model = Xception(include_top=True, weights='imagenet', input_tensor=None, input_shape=None)
  # Diagram of the full, un-truncated network.
  plot_model(base_model, to_file='xception_model.png')
  # Named 'xception' so the nested layer in the outer model keeps the same name
  # the original Xception model had (e.g. for model.get_layer / saved weights).
  base_model = Model(inputs=base_model.input,
                     outputs=base_model.get_layer('avg_pool').output,
                     name='xception')
  feature = base_model

  # Both image inputs go through the same extractor instance, so its weights
  # are shared between the two branches.
  img1 = Input(shape=(299, 299, 3), name='img_1')
  img2 = Input(shape=(299, 299, 3), name='img_2')
  feature1 = feature(img1)
  feature2 = feature(img2)

  # Three outputs, one loss each:
  #   ctg_out_1 / ctg_out_2 -- 100-way softmax classification of each image;
  #                            two separate Dense layers, so these classifier
  #                            weights are NOT shared between branches.
  #   bin_out               -- 2-way softmax computed from the pairwise
  #                            comparison of the two feature vectors.
  category_predict1 = Dense(100, activation='softmax', name='ctg_out_1')(Dropout(0.5)(feature1))
  category_predict2 = Dense(100, activation='softmax', name='ctg_out_2')(Dropout(0.5)(feature2))
  # eucl_dist is defined elsewhere. NOTE(review): the layer name 'square'
  # suggests an element-wise squared difference rather than a scalar
  # Euclidean distance -- confirm against eucl_dist's definition.
  dis = Lambda(eucl_dist, name='square')([feature1, feature2])
  judge = Dense(2, activation='softmax', name='bin_out')(dis)

  model = Model(inputs=[img1, img2], outputs=[category_predict1, category_predict2, judge])
  # NOTE(review): recent Keras versions name the SGD argument `learning_rate`
  # rather than `lr`; check which spelling the installed version expects.
  model.compile(
      optimizer=SGD(lr=0.0001, momentum=0.9),
      loss={
          'ctg_out_1': 'categorical_crossentropy',
          'ctg_out_2': 'categorical_crossentropy',
          'bin_out': 'categorical_crossentropy',
      },
      # The pair (bin_out) loss is weighted at half of each classification loss.
      loss_weights={
          'ctg_out_1': 1.,
          'ctg_out_2': 1.,
          'bin_out': 0.5,
      },
      metrics=['accuracy'],
  )