import os
import time
import uuid

import pandas as pd
from flask import Flask, request, jsonify

from img_search import ImgSearch

# Search backend shared by both routes below.
img_search = ImgSearch()

app = Flask(__name__)
# Uploaded images are written here before their paths are handed to img_search.
UPLOAD_FOLDER = '/mnt/data/img_data/tmp_img/'
# exist_ok=True replaces a separate exists() check, which raced when several
# worker processes imported this module at the same time (the loser of the race
# crashed with FileExistsError).
os.makedirs(UPLOAD_FOLDER, exist_ok=True)


@app.route('/img_search/similarity', methods=['POST'])
def search_amazon_similarity():
    """Return the cosine similarity between two uploaded images.

    Expects multipart form-data with the file parts ``file1`` and ``file2``.
    Both are saved under ``UPLOAD_FOLDER``. Their paths are then passed to
    ``img_search.cosine_similarity_two_img``.

    Returns:
        ``{"similarity": <value>}`` on success, or HTTP 400 with
        ``{"error": ...}`` when either file part is missing.
    """
    print(request.form)  # dump all received form fields
    print(request.files)  # dump all received files
    files = [request.files[key] for key in ('file1', 'file2') if key in request.files]
    if len(files) != 2:
        return jsonify({"error": "必须同时上传 file1 和 file2 两个文件"}), 400

    saved_paths = []
    for file in files:
        milli_timestamp = round(time.time() * 1000)  # seconds -> milliseconds
        # basename() drops any directory component of the client-supplied name.
        # The random token stops two uploads with the same name in the same
        # millisecond from overwriting each other. Without it, identically named
        # file1/file2 would both point at file2's bytes.
        safe_name = os.path.basename(file.filename or '')
        filename = f"{milli_timestamp}_{uuid.uuid4().hex}_{safe_name}"
        file_path = os.path.join(UPLOAD_FOLDER, filename)
        file.save(file_path)
        saved_paths.append(file_path)
    # NOTE(review): saved uploads are never deleted; confirm something else
    # cleans UPLOAD_FOLDER, or remove the files once no longer needed.

    file1_path, file2_path = saved_paths
    similarity = img_search.cosine_similarity_two_img(
        file1_path=file1_path,
        file2_path=file2_path
    )
    # NOTE(review): jsonify cannot serialise NumPy scalars; confirm that
    # cosine_similarity_two_img returns a built-in number.
    return jsonify({"similarity": similarity})


def _rank_search_results(df_similarities_list, top_k):
    """Merge per-query result frames into a ranked, JSON-ready list of row dicts.

    Drops rows with a NaN ``similarity``, sorts by descending ``similarity``,
    assigns a 1-based ``id`` rank and keeps the first ``top_k`` rows. Each
    returned dict starts with ``id``, followed by the remaining columns.
    """
    df_result = pd.concat(df_similarities_list)
    df_result = df_result[df_result['similarity'].notna()]  # drop rows with NaN similarity
    print(f"df_result.columns: {df_result.columns}")
    # A non-inplace sort yields a fresh frame, so adding 'id' below does not
    # write into a masked slice (SettingWithCopyWarning). A stable sort keeps
    # tied rows in concatenation order, so ranks are deterministic.
    df_result = df_result.sort_values('similarity', ascending=False, kind='mergesort')
    df_result['id'] = range(1, len(df_result) + 1)  # rank column, starting at 1
    df_result = df_result.iloc[:top_k]
    other_columns = [col for col in df_result.columns if col != 'id']
    return [
        {'id': row['id'], **{col: row[col] for col in other_columns}}
        for _, row in df_result.iterrows()
    ]


@app.route('/img_search', methods=['POST'])
def search_amazon():
    """Search for similar images, either by uploaded files or by a stored key.

    Form fields (all optional):
        site_name    -- passed to ``img_search.search_api`` (default ``'us'``).
        img_type     -- passed to ``img_search.search_api`` (default ``'amazon_inv'``).
        search_key   -- ``'file'`` searches with each of the uploaded parts
                        ``file1``..``file5``. Any other value makes a single
                        ``search_api`` call with an empty ``file_path``
                        (default ``'img_unique'``).
        search_value -- passed to ``img_search.search_api`` (default ``'B0BBNQCXZL'``).
        top_k        -- positive int; per-query limit passed to ``search_api``
                        and overall cap on returned rows (default 100).

    Returns:
        A JSON list of rows sorted by descending ``similarity``, each with a
        1-based ``id`` rank. Returns HTTP 400 with ``{"error": ...}`` when
        ``top_k`` is not a positive integer, or when ``search_key == 'file'``
        but no file was uploaded.
    """
    print(request.form)  # dump all received form fields
    print(request.files)  # dump all received files
    site_name = request.form.get('site_name', 'us')
    img_type = request.form.get('img_type', 'amazon_inv')
    search_key = request.form.get('search_key', 'img_unique')
    search_value = request.form.get('search_value', 'B0BBNQCXZL')
    try:
        top_k = int(request.form.get('top_k', 100))
    except (TypeError, ValueError):
        top_k = 0  # invalid values are rejected by the check below
    if top_k < 1:
        # A non-positive value would make iloc[:top_k] drop rows from the end.
        return jsonify({"error": "top_k 必须是正整数"}), 400

    df_similarities_list = []
    if search_key == 'file':
        # Only these keys are read, so at most 5 images are searched per
        # request; any other file parts are ignored.
        file_keys = ['file1', 'file2', 'file3', 'file4', 'file5']
        files = [request.files[key] for key in file_keys if key in request.files]
        if not files:
            # Otherwise pd.concat([]) would raise and surface as an HTTP 500.
            return jsonify({"error": "search_key 为 file 时至少需要上传 1 个文件"}), 400
        for file in files:
            milli_timestamp = round(time.time() * 1000)  # seconds -> milliseconds
            # basename() drops any directory component of the client-supplied
            # name. The random token prevents same-name uploads in the same
            # millisecond (e.g. concurrent requests) from overwriting each other.
            safe_name = os.path.basename(file.filename or '')
            filename = f"{milli_timestamp}_{uuid.uuid4().hex}_{safe_name}"
            file_path = os.path.join(UPLOAD_FOLDER, filename)
            file.save(file_path)
            print("file saved at:", file_path)  # debug log

            # Run one search per uploaded image.
            df_similarities_list.append(img_search.search_api(
                site_name=site_name, img_type=img_type, search_key=search_key,
                search_value=search_value, top_k=top_k, file_path=file_path
            ))
    else:
        # Key-based lookup: no uploaded image, so file_path is empty.
        df_similarities_list.append(img_search.search_api(
            site_name=site_name, img_type=img_type, search_key=search_key,
            search_value=search_value, top_k=top_k, file_path=''
        ))

    return jsonify(_rank_search_results(df_similarities_list, top_k))


if __name__ == "__main__":
    # Script entry point: start Flask's built-in development server on all
    # interfaces. NOTE(review): this server is not meant for production traffic;
    # consider running behind a WSGI server instead.
    app.run(port=10001, host="0.0.0.0")
