# tf_dome2.py
# Scores product listings from PostgreSQL with a previously saved Keras
# text-classification model and writes the matching ASINs to a report file.
# (Git web-viewer artifacts — filename/size, committer, line-number run —
# removed from this header; they were not part of the program.)
import pandas as pd
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
import psycopg2

# Load the previously saved model, and reuse the exact Tokenizer and
# max_length the model was trained with (imported from the training script).
# NOTE(review): importing tf_dome1 executes that module at import time —
# presumably it rebuilds/refits the tokenizer; confirm it has no heavy side
# effects (e.g. retraining) before running this script.
from tensor_flow.tf_dome1 import tokenizer, max_length

# Restore the trained Keras model from its HDF5 checkpoint.
model = load_model('text_classification_model.h5')

# Connect to the PostgreSQL database.
# SECURITY NOTE(review): credentials are hard-coded in source; move host/user/
# password into environment variables or a secrets store before sharing this
# file or deploying it.
connection = psycopg2.connect(
    host="192.168.10.223",
    port="5432",  # replace with your PostgreSQL port if different
    database="selection",
    user="postgres",
    password="fazAqRRVV9vDmwDNRNb593ht5TxYVrfTyHJSJ3BS"
)

# Fetch December-2023 listings that have a non-empty product description.
# BUG FIX: the original query ended with a dangling "limit" keyword followed
# by an empty continuation string ("... limit " ""), which is a PostgreSQL
# syntax error. The stray keyword is removed; re-add "LIMIT <n>" explicitly
# if row capping is desired.
query = (
    "select asin, title, product_description, describe "
    "from us_asin_detail_month_2023_12 "
    "where date_info = '2023-12' "
    "and spider_int = 2 "
    "and product_description != ''"
)

try:
    # Read the result set into a DataFrame (columns: asin, title,
    # product_description, describe).
    new_data = pd.read_sql(query, connection)
finally:
    # Always release the connection, even if the read fails.
    connection.close()

# Build one free-text string per row from the three text columns.
# NOTE: the columns are concatenated BEFORE fillna, so a row where any one
# column is NaN collapses to an empty string (NaN propagates through '+');
# this mirrors the behavior the model was trained against.
combined = (
    new_data['describe'] + ' ' + new_data['title'] + ' ' + new_data['product_description']
)
new_texts = combined.fillna('').astype(str)

# Convert the raw text into integer token sequences using the SAME tokenizer
# the model was trained with — a fresh tokenizer would produce a different
# vocabulary mapping and garbage predictions.
new_sequences = tokenizer.texts_to_sequences(new_texts)

# Pad/truncate every sequence to max_length (post-padding, post-truncating)
# so all inputs share the fixed shape the network expects.
new_padded = pad_sequences(new_sequences, maxlen=max_length, padding='post', truncating='post')

# Score every row; new_predictions[i][0] is the per-row probability
# (presumably "is Christmas-related" — confirm against the training labels).
new_predictions = model.predict(new_padded)

# Probability threshold above which a listing counts as a match.
threshold = 0.5

# Collected records for every ASIN whose score clears the threshold,
# together with the text fields that produced the prediction.
matching_asins_info = []

# Walk predictions and rows in lockstep (model.predict preserves row order,
# so position idx in new_predictions corresponds to row idx of new_data).
for idx, probs in enumerate(new_predictions):
    asin = new_data["asin"].iloc[idx]
    print('new_predictions[0]:', probs[0])
    # Keep the row only when the model scores it at or above the threshold.
    if probs[0] >= threshold:
        print('asin:', asin)
        matching_asins_info.append({
            'ASIN': asin,
            'Probabilities': probs,
            'MatchingFields': {
                'Title': new_data['title'].iloc[idx],
                'Description': new_data['describe'].iloc[idx],
                'ProductDescription': new_data['product_description'].iloc[idx]
            }
        })

# Write the matched records to a plain-text report.
# NOTE(review): the file has a .csv extension but the content is a key/value
# text report, not CSV; the hard-coded Windows path also ties this script to
# one machine — consider making it configurable.
output_file_path = 'E:/BaiduNetdiskDownload/选品大数据/推荐系统/matching_asins_info2.csv'
with open(output_file_path, 'w', encoding='utf-8') as output_file:
    for info in matching_asins_info:
        output_file.write(f"ASIN: {info['ASIN']}\n")

        fields = info['MatchingFields']
        # Pick the field to report: Description wins, then ProductDescription,
        # otherwise fall back to Title.
        # BUG FIX: the original looked up the content with the display label
        # 'Product Description' (with a space), but the dict key is
        # 'ProductDescription', so the content always came out as 'N/A'.
        # The dict key and the human-readable label are now tracked separately.
        if fields.get('Description'):
            matched_key, matched_label = 'Description', 'Description'
        elif fields.get('ProductDescription'):
            matched_key, matched_label = 'ProductDescription', 'Product Description'
        else:
            matched_key, matched_label = 'Title', 'Title'

        output_file.write(f"Matched Field: {matched_label}\n")
        output_file.write(f"Matched Content: {fields.get(matched_key, 'N/A')}\n")
        output_file.write('-' * 50 + '\n')