Files
NewQuant/test/test.ipynb

221 lines
124 KiB
Plaintext
Raw Normal View History

2025-09-24 23:14:14 +08:00
{
"cells": [
{
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-09-22T11:16:11.127998Z",
"start_time": "2025-09-22T11:16:11.126254Z"
}
},
"cell_type": "code",
"source": "",
"id": "initial_id",
"outputs": [],
"execution_count": null
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-09-22T11:16:37.464945Z",
"start_time": "2025-09-22T11:16:37.392136Z"
}
},
"cell_type": "code",
"source": [
"from src.algo.TrendLine import calculate_latest_trendline_values\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def _envelope_slope(prices: np.ndarray, anchor_idx: int, above: bool, tol: float) -> float:\n",
"    \"\"\"Return the slope of the tightest line through prices[anchor_idx] that bounds every price.\n",
"\n",
"    Each other point i proposes the slope of the line joining it to the anchor.\n",
"    A candidate is feasible when that line lies on or above (above=True) or on\n",
"    or below (above=False) every price, within the absolute tolerance tol.\n",
"    Among feasible candidates the one with the smallest summed vertical gap to\n",
"    the prices wins; on ties the earliest index is kept.\n",
"\n",
"    A single-point series has no second point to pair with, so a horizontal\n",
"    line (slope 0.0) is returned.\n",
"\n",
"    Raises:\n",
"        ValueError: if no candidate is feasible, e.g. because the prices contain\n",
"            NaN, or because rounding error on very large prices exceeds tol.\n",
"    \"\"\"\n",
"    n = len(prices)\n",
"    if n == 1:\n",
"        return 0.0\n",
"    x = np.arange(n)\n",
"    anchor_price = prices[anchor_idx]\n",
"\n",
"    best_slope = None\n",
"    min_gap_sum = float('inf')\n",
"    for i in range(n):\n",
"        if i == anchor_idx:\n",
"            continue\n",
"\n",
"        # Candidate slope (y2 - y1) / (x2 - x1) of the line joining point i to the anchor\n",
"        candidate_slope = (prices[i] - anchor_price) / (i - anchor_idx)\n",
"        # y = m*x + c  =>  c = y - m*x, evaluated at the anchor\n",
"        intercept = anchor_price - candidate_slope * anchor_idx\n",
"        candidate_line = candidate_slope * x + intercept\n",
"\n",
"        # Keep only lines on the required side of every price (within tol)\n",
"        if above:\n",
"            if not np.all(candidate_line >= prices - tol):\n",
"                continue\n",
"            gap_sum = np.sum(candidate_line - prices)\n",
"        else:\n",
"            if not np.all(candidate_line <= prices + tol):\n",
"                continue\n",
"            gap_sum = np.sum(prices - candidate_line)\n",
"\n",
"        # Strict '<' keeps the earliest candidate on ties\n",
"        if gap_sum < min_gap_sum:\n",
"            min_gap_sum = gap_sum\n",
"            best_slope = candidate_slope\n",
"\n",
"    if best_slope is None:\n",
"        raise ValueError('no bounding trendline found; check prices for NaN/inf or increase tol')\n",
"    return best_slope\n",
"\n",
"\n",
"def get_trendlines(prices: np.ndarray, tol: float = 1e-9):\n",
"    \"\"\"Compute the upper and lower trendlines of a price series.\n",
"\n",
"    The upper trendline passes through the highest price and lies on or above\n",
"    every price; the lower trendline passes through the lowest price and lies\n",
"    on or below every price. Each is chosen among the lines joining that extreme\n",
"    to one other point, minimising the summed vertical distance to the prices.\n",
"\n",
"    Args:\n",
"        prices: 1-D sequence of prices (anything np.asarray accepts).\n",
"        tol: absolute tolerance used when checking that a candidate line stays\n",
"            on the correct side of the prices.\n",
"\n",
"    Returns:\n",
"        A tuple (upper_trendline, lower_trendline) of float arrays, each the\n",
"        same length as prices.\n",
"\n",
"    Raises:\n",
"        ValueError: if prices is not a non-empty 1-D array, or no bounding\n",
"            line can be found (see _envelope_slope).\n",
"    \"\"\"\n",
"    prices = np.asarray(prices, dtype=float)\n",
"    if prices.ndim != 1 or prices.size == 0:\n",
"        raise ValueError('prices must be a non-empty 1-D array')\n",
"    x = np.arange(len(prices))\n",
"\n",
"    # --- Upper trendline: anchored at the highest price ---\n",
"    high_point_idx = int(np.argmax(prices))\n",
"    upper_slope = _envelope_slope(prices, high_point_idx, above=True, tol=tol)\n",
"    upper_intercept = prices[high_point_idx] - upper_slope * high_point_idx\n",
"    upper_trendline = upper_slope * x + upper_intercept\n",
"\n",
"    # --- Lower trendline: anchored at the lowest price ---\n",
"    low_point_idx = int(np.argmin(prices))\n",
"    lower_slope = _envelope_slope(prices, low_point_idx, above=False, tol=tol)\n",
"    lower_intercept = prices[low_point_idx] - lower_slope * low_point_idx\n",
"    lower_trendline = lower_slope * x + lower_intercept\n",
"\n",
"    return upper_trendline, lower_trendline\n",
"\n",
"\n",
"# --- Example and visualisation ---\n",
"\n",
"# 1. Generate a price series of length n (n=50)\n",
"n = 50\n",
"np.random.seed(42)\n",
"base_prices = 100 + np.cumsum(np.random.randn(n)) * 1.5\n",
"# Add extra noise on top of the random walk\n",
"noise = np.random.randn(n) * 2\n",
"sample_prices = base_prices + noise\n",
"\n",
"# 2. Compute the trendlines and cross-check against the library implementation\n",
"upper_line, lower_line = get_trendlines(sample_prices)\n",
"\n",
"print(upper_line)\n",
"print(lower_line)\n",
"latest_values = calculate_latest_trendline_values(sample_prices)\n",
"print(latest_values)\n",
"# Expected to be (upper, lower) at the last bar; it must agree with the lines computed above\n",
"np.testing.assert_allclose(latest_values, (upper_line[-1], lower_line[-1]))\n",
"\n",
"# 3. Visualise the result\n",
"plt.style.use('seaborn-v0_8-darkgrid')\n",
"# Configure a CJK-capable font before any text artist is created so the Chinese labels render\n",
"plt.rcParams['font.sans-serif'] = ['SimHei']  # SimHei is a Chinese Heiti (sans-serif) font\n",
"plt.rcParams['axes.unicode_minus'] = False  # avoid the Unicode minus sign, which SimHei may not render\n",
"fig, ax = plt.subplots(figsize=(14, 7))\n",
"\n",
"# Raw price series\n",
"ax.plot(sample_prices, 'o-', label='价格序列', color='skyblue', markersize=5)\n",
"\n",
"# Trendlines\n",
"ax.plot(upper_line, label='上趋势线', color='red', linestyle='--', linewidth=2)\n",
"ax.plot(lower_line, label='下趋势线', color='green', linestyle='--', linewidth=2)\n",
"\n",
"# Mark the highest and lowest prices (the trendline anchors)\n",
"high_idx = np.argmax(sample_prices)\n",
"low_idx = np.argmin(sample_prices)\n",
"ax.plot(high_idx, sample_prices[high_idx], 'v', color='darkred', markersize=10, label='最高价')\n",
"ax.plot(low_idx, sample_prices[low_idx], '^', color='darkgreen', markersize=10, label='最低价')\n",
"\n",
"ax.set_title('价格序列及算法生成的趋势线', fontsize=16)\n",
"ax.set_xlabel('时间步', fontsize=12)\n",
"ax.set_ylabel('价格', fontsize=12)\n",
"ax.legend(fontsize=12)\n",
"ax.grid(True)\n",
"plt.show()"
],
"id": "67848260b39bf0dc",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[113.49570079 112.9597348 112.42376882 111.88780283 111.35183685\n",
" 110.81587087 110.27990488 109.7439389 109.20797291 108.67200693\n",
" 108.13604095 107.60007496 107.06410898 106.52814299 105.99217701\n",
" 105.45621103 104.92024504 104.38427906 103.84831307 103.31234709\n",
" 102.77638111 102.24041512 101.70444914 101.16848315 100.63251717\n",
" 100.09655119 99.5605852 99.02461922 98.48865323 97.95268725\n",
" 97.41672127 96.88075528 96.3447893 95.80882331 95.27285733\n",
" 94.73689135 94.20092536 93.66495938 93.12899339 92.59302741\n",
" 92.05706143 91.52109544 90.98512946 90.44916347 89.91319749\n",
" 89.37723151 88.84126552 88.30529954 87.76933355 87.23336757]\n",
"[96.89424075 96.54985396 96.20546716 95.86108036 95.51669356 95.17230676\n",
" 94.82791996 94.48353317 94.13914637 93.79475957 93.45037277 93.10598597\n",
" 92.76159917 92.41721237 92.07282558 91.72843878 91.38405198 91.03966518\n",
" 90.69527838 90.35089158 90.00650479 89.66211799 89.31773119 88.97334439\n",
" 88.62895759 88.28457079 87.940184 87.5957972 87.2514104 86.9070236\n",
" 86.5626368 86.21825 85.8738632 85.52947641 85.18508961 84.84070281\n",
" 84.49631601 84.15192921 83.80754241 83.46315562 83.11876882 82.77438202\n",
" 82.42999522 82.08560842 81.74122162 81.39683483 81.05244803 80.70806123\n",
" 80.36367443 80.01928763]\n",
"(np.float64(87.23336756933077), np.float64(80.01928763164909))\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1400x700 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABH8AAAJvCAYAAADiLX0jAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XWYVNUbwPHvnd5mC5BGYOkUEBUEEcWgQxAkRPmBgEgoKIiKlCCIhEqKdEinIEhKSZdId2339P39se7IsEnuLryf5+GBOffcO2dm5+zufXnPexRVVVWEEEIIIYQQQgghxGNJk9UDEEIIIYQQQgghhBAPjwR/hBBCCCGEEEIIIR5jEvwRQgghhBBCCCGEeIxJ8EcIIYQQQgghhBDiMSbBHyGEEEIIIYQQQojHmAR/hBBCCCGEEEIIIR5jEvwRQgghhBBCCCGEeIxJ8EcIIYQQQgghhBDiMSbBHyGEEEIIIYQQQojHmAR/hBBCCPHYiouLw+l0ptvHZrPhcDge0YjE3XA4HISGhhIXF5dmnxs3bjzCEQkhhBA5kwR/hBBCiEyaPHkyXbt2ZdeuXSmOxcTE0KlTJ/744w9UVc3U9ZxOJ6NHj2blypWpHj98+PADv7E9cuQI7733HgMHDnyg183I4cOH+eWXXzhy5Mgjfd4+ffpQr1499uzZk2afRYsW8fzzzzN27NhMXzejgBLA119/TeXKldm7d2+mr5uWGzdu8NZbb9GvX7/7vlZaunfvTv/+/VFVFZvNxpgxYwgLC8vUuX/99Rfbt2/HYrGkevzIkSOsWLGCxMTEuxpTaGgoNWvWZMyYMakev3z5MnXq1KFr1653fd0uXbrwww8/pHr8gw8+oGfPnncdFDxy5AgHDx5M9djVq1eZNm0aa9asydTnRwghhHiQdFk9ACGEECKn2LhxI0eOHKFLly4pjq1cuZI///wTi8VC3bp1M3U9jUbDzJkzqV+/Po0aNUpx/JtvvuHw4cMsXryYMmXK3Pf4AQIDA9mxYwd+fn4MGTIEjSbj/wfq378/Hh4e6fZVVRWHw0FCQgKjR49OcXzPnj2MGTOGzz77jAoVKrgdmzdvHlFRUej1+kyNJ1nu3Llp2LBhmsfXrl3L1q1bKVWqFKqqugJATqcTq9VKqVKl8PHxYdq0acTHx1OsWDG3IJHD4SAxMZG6deuiKIqrffny5UyfPp1vvvmGsmXLpvn8Go2GhIQEdLrM/7oVFxfHuHHjaN68OaVKlXI7dvjwYbRabaavdbcOHTqE2WxGURT+/PNPpk6dypo1a5gyZQrFixdP99yVK1eyaNEivvvuO958880Ux7ds2cIPP/zA/v37+frrrzM9JoPBAIDJZEr1+JIlS1BVlQYNGmT6mpD0tdmyZQsFCxZM9XhoaChhYWF3/X7/8MMPbNmyhQEDBtChQwe3Y6qq8u2331KwYEHefPNNzp8/z5o1a+jatetdfUaEEEKIeyE/aYQQQohMuHz5MseOHaN69epUrlzZ7ZjVamX69OlA0g36sWPHKFeuXKauazAY8PT0TNF+8uRJDhw4QIMGDe4q8HPy5ElOnDiRbp+goCDCwsKYMmUKuXPndrU7HA5sNht2u50qVaq4XsPVq1cxmUyuG+EdO3Zgt9upWbOm66bV6XS6giWpMRqNAPj4+KQ4tnLlSv755x/0ej2KohATE4NGo8Hb2zvVa8XGxuJwOBg5cmSar/HixYt88cUXrvfkzhtxgHHjxrFnzx6uXbsGwCeffJLqtY4fP+52c37p0iXOnDnDW2+9xccff8y7776b6nl6vR7grm7sDxw4wKxZs1i9ejW//fYbfn5+wH9BkIcZJPD09MRutwNQp04dvv32Wz799FPatWvHjBkzUgSjbnf69Gm8vb155ZVXUj2+a9cutFotnTt3znAcZrMZu92O0Wh0vW5FUbBYLJjNZjw9PdHr9SQmJjJ//nwURWHNmjWsXbvW7To9e/Z0G/ORI0
ewWCxotVoSEhKApGDbvn378PDwoHDhwpw5cwZIWgqoKAobN250zQmbzUbDhg1dX9c7nT9/nm3btlG9enXeeeedFMcDAwMBKFCgAAB79+5lwoQJ3Lp1664CYkIIIcS9kOCPEEIIkQlz5szB6XSmeqM/c+ZMrl27Rq9evfjtt9/45JNPWLhwIb6+vhle9/aMkttNmDABgA0bNqQINgG8/PLLqWbYbN68me+//x5PT0/XjXNqcuXKxYwZM9za7Ha760/fvn1dwZ85c+a4+sTExFCtWjV8fX1dAa87OZ1OXn/9dZxOJ4sXL8bPz881Fi8vrxT9FyxY4Pa4Zs2alChRIsX4AG7evEm9evUoVqxYqtlSAOHh4fzvf/8jPj6eBQsWEBISQrt27fj777/5448/8PLywmKxsHHjRubNm0eLFi0YMGAAqqpSt25dAgMDWbRoEQ6HA4vFkiLg0rNnT6pVq0avXr345ptviI+Pp3PnzlitVreARXIW0+1fY5vN5gpAeHh4pBj7/v37AejUqZMr8POomEwmoqOjXY8bNmxIYmIi06dPT3Wstzt79ix16tRJ9TMXGhrK4cOHqV27dpqZNrdbvHgxQ4YMcWubPn266/M2d+5cqlatys8//0xUVBQtW7Z0BVYiIyNZuHAhlStXThGsGjFiBAcOHHBrW7ZsGcuWLaNOnTr069ePVq1auR3v27cvRqORuLg4NBoNTZs2TXPc33//PQaDgeHDh6eaMZT8HiZnMbVq1Yo9e/Zw7NgxQkNDCQ4OzvC9EUIIIe6VBH+EEEKIDMTFxbFkyRLy589PnTp13I6dO3eOH374gTJlyvC///2PV199lRYtWvDee+8xY8aMFNkrNpuN48ePYzAY0Ov1OJ1OYmNjOX36NImJiZQvX55du3axceNGnnrqKT799FO385ctW8aWLVto3rx5qmNNzkro378/rVu3fnBvwr+Ss4ruXLp1O41GQ3R0NNHR0ZkKgN0pPj7eLSPpdtOmTcNqtdKvX780l4j99ttv3Lp1i/bt27sCZ5UqVeL48eMcO3aMV155BYvF4goE9erVyxWUqlGjBhcvXiQqKirdQMVzzz3HokWL+Pbbb+nYsSNLly7lq6++SrVvixYtUrT16NGDDz/8MEX7unXr0Ov1aX5979WkSZPcAjupiYyMJCEhwS2jyul0UrVqVaZOnUq9evXcPv/bt2/n1KlTmM1mYmJiMJvNbgHBBg0akCdPHtasWYPD4eDEiRM0a9YsxfO+8cYbvP/++67HyfWXTCYTTqeT7t2789prr9G4cWPMZjNFixbl5s2bTJs2jdKlSzNkyBBXgG3o0KEAfPbZZymeZ/DgwdjtdlfWUMuWLWnTpg2tWrVCr9dTpEgR/vzzTwICAmjSpAk+Pj7MnTsXgLp16+J0OtMM1u7bt4/169fTq1cvChYsSGhoKF988QXdunWjfPnybn1vrwk2dOhQHA4HBw4cICoqisaNG6f59RFCCCHuhwR/hBBCiAz89NNPxMbGUrp0abeAQ2RkJD169ECr1TJmzBi0Wi3FihVj7NixdO/enbZt2zJu3DiKFCniOicqKipFdsHatWtdS1Z27tzJF198gY+PD9evX8fpdPLGG2+4zv3iiy+oWrUqzz33XKpjTc5SSb5JXbduHadPn3bV00nr5jVZ7dq1KVmyZJrHk4sXlyhRIkUxak9PT1ewx2QykZCQkOHz3SkuLo6EhATy5s2b4tiNGzdYuHAhL774IrVq1UrzGm3btqVevXr4+/sTFxfHzZs3ee6556hVqxaFChXi7NmzAHz66adcuXKFmJgYYmJiAOjWrRt6vZ6AgIAMx1q4cGEmTpwIQOXKlfniiy8wGAyurI/ffvuNrVu30qNHD/Lnzw/8t7QutaV8e/bs4eLFi/j7+zNq1Ci3Y1arFUgKNt4ZEHQ4HFitVkqUKEGPHj1SHeuvv/7K9evX8ff3T/P1xMXFYbPZ3AqQOxwOHA4HZrOZQoUKuQV/Nm
3axPz5812PN27cyMaNG12Pa9SoQXBwMHPnzsXT05NChQq5PZ/dbufAgQPUqFHDrf3pp5/m6aefBnAVkC5YsKCrlpa
},
"metadata": {},
"output_type": "display_data"
}
],
"execution_count": 9
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}