基于提供的main5.py代码,以下是详细的代码结构梳理:


[[E高速列车轴承智能故障诊断问题-论文_任务二_模型架构]]

📁 代码整体架构

1. 导入模块与配置 (Lines 1-12)

1
2
3
4
5
6
7
8
import os, glob, math, argparse, random, warnings
from pathlib import Path
import numpy as np, scipy.io as sio, scipy.signal as sig
from scipy.stats import kurtosis
import yaml, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.manifold import TSNE, sklearn.metrics
import matplotlib.pyplot as plt

🔧 核心功能模块

2. 配置与轴承机理 (Lines 15-39)

  • load_cfg(p): 加载YAML配置文件,设置预处理默认值
  • bearing_freqs(rpm, n, d, D, theta_deg): 计算轴承故障特征频率(BPFI、BPFO、BSF、FTF)

3. 信号预处理模块 (Lines 42-128)

  • detrend(x): 线性去趋势
  • bandpass(x, fs, f1, f2): 带通滤波
  • envelope(x): 包络提取
  • zscore(x): Z-score归一化
  • resample_if_needed(x, fs_src, fs_tgt): 重采样
  • windowing(x, fs, win_sec, overlap): 信号分段
  • spectral_kurtosis_band(x, fs): 谱峭度滤波频带选择
  • order_resample(x, fs, rpm, spr): 阶次重采样

4. 数据加载模块 (Lines 131-183)

  • read_mat_any(path): 通用MAT文件读取,兼容多种变量命名
  • ensure_1d(x): 确保数据为1D格式
  • _pick_multich_from_rec(rec, prefer_keys, max_ch): 多通道数据选择

📊 数据集类

5. 原始数据集 (Lines 186-302)

  • SourceBearingDS: 源域数据集(MAT文件在线预处理)
  • TargetBearingDS: 目标域数据集(MAT文件在线预处理)

6. 预处理数据集 (Lines 305-357)

  • PreprocessedSourceDS: 源域预处理数据集(从NPY文件加载)
  • PreprocessedTargetDS: 目标域预处理数据集(从NPY文件加载)
  • _read_index_csv(csv_path): 读取预处理索引文件

🧠 神经网络模块

7. 网络组件 (Lines 360-449)

  • MechanismFilterBank: 机理感知滤波器组
  • ChannelSE: 通道注意力机制
  • FeatNet1D: 1D卷积特征提取器
  • GateExplain: 可解释门控机制
  • Classifier: 分类器
  • BearingNet: 主网络架构

8. 模型工具 (Lines 452-472)

  • safe_load(model, ckpt_path): 安全加载模型权重
  • _strip_module_prefix(state_dict): 处理分布式训练前缀

📈 损失函数与训练

9. 损失函数 (Lines 475-488)

  • coral_loss(source, target): CORAL域适配损失
  • entropy_minimization(logits): 熵最小化损失

10. 训练相关工具 (Lines 491-548)

  • build_centers(cfg, rpm): 构建机理频率中心
  • set_seed(s): 设置随机种子
  • compute_class_weights(ds): 计算类别权重
  • evaluate(model, loader): 模型评估
  • recalibrate_bn(model, dl_tgt): 批归一化重校准

11. 核心训练流程 (Lines 551-757)

  • train_source(cfg, out): 源域监督训练
  • adapt_coral(cfg, ckpt, out): CORAL+SHOT域适配
  • infer_target(cfg, ckpt, out): 目标域推理
  • visualize_embeddings(cfg, ckpt, out): t-SNE可视化

🎨 可视化与分析模块

12. 特征提取与分析 (Lines 760-870)

  • extract_source_features(cfg, out_dir): 提取源域机理特征(FBE等)
  • train_feat_classifier(feat_csv, out_dir): 训练随机森林基线
  • viz_features_from_csv(csv_path, out_dir): 特征可视化

13. 包络谱可视化 (Lines 873-968)

  • viz_envelope_examples(cfg, out_dir): 生成包络谱示例图,双子图布局(0-1000Hz主频段 + 1000-3000Hz高频段)

14. 其他可视化工具 (Lines 971-1076)

  • plot_confusion(cm, out_path): 混淆矩阵绘制
  • viz_source_confusion(cfg, out): 源域混淆矩阵
  • viz_gate_channel_weights(cfg, out): 门控权重与通道注意力可视化
  • full_viz_report(cfg, out): 一键生成完整可视化报告

🚀 主程序入口

15. 命令行接口 (Lines 1079-1121)

1
2
3
if __name__=="__main__":
# 参数解析:--cfg, --stage, --lambda_coral, --shot, --out, --feat_csv, --device
# 支持阶段:train_source, adapt, infer, viz, feat, feat_clf, viz_report

📋 流程依赖关系

graph TD
    A[数据预处理] --> B[特征提取 feat]
    B --> C[RF基线训练 feat_clf]
    A --> D[源域训练 train_source]
    D --> E[域适配 adapt]
    E --> F[目标域推理 infer]
    C --> G[可视化报告 viz_report]
    F --> G
    E --> H[t-SNE可视化 viz]

🎯 关键特点

  1. 模块化设计: 清晰分离数据处理、模型定义、训练、可视化
  2. 多模式支持: 同时支持在线MAT处理和预处理NPY数据
  3. 机理融合: 集成轴承故障机理的滤波器组和门控机制
  4. 域适配: 实现CORAL+SHOT的无监督域适配
  5. 可解释性: 提供特征重要度、注意力权重、包络谱等多种可视化
  6. 容错性: 安全的模型加载、动态通道处理、异常处理

该代码架构适合轴承故障诊断的迁移学习任务,既保证了算法先进性,又提供了丰富的可解释性分析工具。


http://example.com/posts/138.html
作者
司马吴空
发布于
2026年3月30日
许可协议