1、vp策略
This commit is contained in:
@@ -76,4 +76,162 @@ def calculate_latest_trendline_values(prices: np.ndarray) -> Tuple[Optional[floa
|
||||
lower_intercept = low_point_price - best_lower_slope * low_point_idx
|
||||
latest_lower_value = best_lower_slope * (n - 1) + lower_intercept
|
||||
|
||||
return latest_upper_value, latest_lower_value
|
||||
return latest_upper_value, latest_lower_value
|
||||
|
||||
|
||||
def calculate_latest_trendline_values_v2(prices: np.ndarray) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""
|
||||
【V3 最终修正版】
|
||||
根据给定的价格序列,仅计算并返回上、下趋势线在最后一个点的值。
|
||||
优化点:通过从两端向中间搜索凸包/凹包顶点的方式进行剪枝。
|
||||
"""
|
||||
n = len(prices)
|
||||
if n < 2:
|
||||
return None, None
|
||||
|
||||
x = np.arange(n)
|
||||
|
||||
# --- 计算上趋势线 ---
|
||||
high_point_idx = np.argmax(prices)
|
||||
high_point_price = prices[high_point_idx]
|
||||
best_upper_slope = None
|
||||
min_upper_distance_sum = float('inf')
|
||||
|
||||
# --- 修正点: 从最左侧向最高点搜索 ---
|
||||
left_max_price = -1.0
|
||||
for i in range(high_point_idx): # 遍历最高点左侧的所有点
|
||||
if prices[i] > left_max_price:
|
||||
# 这是一个候选点,进行计算
|
||||
candidate_slope = (high_point_price - prices[i]) / (high_point_idx - i)
|
||||
intercept = high_point_price - candidate_slope * high_point_idx
|
||||
candidate_line = candidate_slope * x + intercept
|
||||
if np.all(candidate_line >= prices - 1e-9):
|
||||
distance_sum = np.sum(candidate_line - prices)
|
||||
if distance_sum < min_upper_distance_sum:
|
||||
min_upper_distance_sum = distance_sum
|
||||
best_upper_slope = candidate_slope
|
||||
# 更新左侧迄今为止的最高点
|
||||
left_max_price = prices[i]
|
||||
|
||||
# --- 修正点: 从最右侧向最高点搜索 ---
|
||||
right_max_price = -1.0
|
||||
for i in range(n - 1, high_point_idx, -1): # 遍历最高点右侧的所有点
|
||||
if prices[i] > right_max_price:
|
||||
# 这是一个候选点,进行计算
|
||||
candidate_slope = (prices[i] - high_point_price) / (i - high_point_idx)
|
||||
intercept = high_point_price - candidate_slope * high_point_idx
|
||||
candidate_line = candidate_slope * x + intercept
|
||||
if np.all(candidate_line >= prices - 1e-9):
|
||||
distance_sum = np.sum(candidate_line - prices)
|
||||
if distance_sum < min_upper_distance_sum:
|
||||
min_upper_distance_sum = distance_sum
|
||||
best_upper_slope = candidate_slope
|
||||
# 更新右侧迄今为止的最高点
|
||||
right_max_price = prices[i]
|
||||
|
||||
if best_upper_slope is None:
|
||||
# 如果循环没有找到任何有效的线(例如,只有一个点或所有点在一条直线上)
|
||||
# 这种情况很少见,但为了稳健性,可以默认水平线
|
||||
best_upper_slope = 0.0
|
||||
|
||||
upper_intercept = high_point_price - best_upper_slope * high_point_idx
|
||||
latest_upper_value = best_upper_slope * (n - 1) + upper_intercept
|
||||
|
||||
# --- 计算下趋势线 (逻辑对称) ---
|
||||
low_point_idx = np.argmin(prices)
|
||||
low_point_price = prices[low_point_idx]
|
||||
best_lower_slope = None
|
||||
min_lower_distance_sum = float('inf')
|
||||
|
||||
# --- 修正点: 从最左侧向最低点搜索 ---
|
||||
left_min_price = float('inf')
|
||||
for i in range(low_point_idx):
|
||||
if prices[i] < left_min_price:
|
||||
candidate_slope = (low_point_price - prices[i]) / (low_point_idx - i)
|
||||
intercept = low_point_price - candidate_slope * low_point_idx
|
||||
candidate_line = candidate_slope * x + intercept
|
||||
if np.all(candidate_line <= prices + 1e-9):
|
||||
distance_sum = np.sum(prices - candidate_line)
|
||||
if distance_sum < min_lower_distance_sum:
|
||||
min_lower_distance_sum = distance_sum
|
||||
best_lower_slope = candidate_slope
|
||||
left_min_price = prices[i]
|
||||
|
||||
# --- 修正点: 从最右侧向最低点搜索 ---
|
||||
right_min_price = float('inf')
|
||||
for i in range(n - 1, low_point_idx, -1):
|
||||
if prices[i] < right_min_price:
|
||||
candidate_slope = (prices[i] - low_point_price) / (i - low_point_idx)
|
||||
intercept = low_point_price - candidate_slope * low_point_idx
|
||||
candidate_line = candidate_slope * x + intercept
|
||||
if np.all(candidate_line <= prices + 1e-9):
|
||||
distance_sum = np.sum(prices - candidate_line)
|
||||
if distance_sum < min_lower_distance_sum:
|
||||
min_lower_distance_sum = distance_sum
|
||||
best_lower_slope = candidate_slope
|
||||
right_min_price = prices[i]
|
||||
|
||||
if best_lower_slope is None:
|
||||
best_lower_slope = 0.0
|
||||
|
||||
lower_intercept = low_point_price - best_lower_slope * low_point_idx
|
||||
latest_lower_value = best_lower_slope * (n - 1) + lower_intercept
|
||||
|
||||
return latest_upper_value, latest_lower_value
|
||||
# ==============================================================================
|
||||
# 验证代码
|
||||
# ==============================================================================
|
||||
if __name__ == '__main__':
|
||||
import timeit
|
||||
from tqdm import tqdm
|
||||
|
||||
for i in tqdm(range(1000)):
|
||||
# 1. 生成一段模拟的价格序列
|
||||
np.random.seed(42)
|
||||
n_points = 200 # 使用一个较大的值来体现性能差异
|
||||
base_prices = 100 + np.cumsum(np.random.randn(n_points)) * 0.5
|
||||
noise = np.random.uniform(-1, 1, n_points)
|
||||
sample_prices = base_prices + noise
|
||||
|
||||
# print(f"--- 验证开始 (数据点: {n_points}) ---")
|
||||
|
||||
# 2. 调用 V1 和 V2 版本
|
||||
v1_upper, v1_lower = calculate_latest_trendline_values(sample_prices)
|
||||
v2_upper, v2_lower = calculate_latest_trendline_values_v2(sample_prices)
|
||||
|
||||
# print(f"V1 结果: Upper={v1_upper:.4f}, Lower={v1_lower:.4f}")
|
||||
# print(f"V2 结果: Upper={v2_upper:.4f}, Lower={v2_lower:.4f}")
|
||||
|
||||
# 3. 比对结果
|
||||
# 使用 np.isclose 来处理浮点数的微小误差
|
||||
results_match = np.isclose(v1_upper, v2_upper) and np.isclose(v1_lower, v2_lower)
|
||||
# print(f"\n结果是否一致: {results_match}")
|
||||
if not results_match:
|
||||
print("警告:V1 和 V2 版本计算结果不一致,请检查算法逻辑!")
|
||||
quit(-1)
|
||||
|
||||
|
||||
# 确保 timeit 可以访问到函数和数据
|
||||
setup_code = """
|
||||
import numpy as np
|
||||
from __main__ import calculate_latest_trendline_values, calculate_latest_trendline_values_v2, sample_prices
|
||||
"""
|
||||
|
||||
v1_time = timeit.timeit(
|
||||
"calculate_latest_trendline_values(sample_prices)",
|
||||
setup=setup_code,
|
||||
number=1000
|
||||
)
|
||||
|
||||
v2_time = timeit.timeit(
|
||||
"calculate_latest_trendline_values_v2(sample_prices)",
|
||||
setup=setup_code,
|
||||
number=1000
|
||||
)
|
||||
|
||||
print(f"V1 (原始) 版本总耗时: {v1_time:.6f} 秒")
|
||||
print(f"V2 (优化) 版本总耗时: {v2_time:.6f} 秒")
|
||||
|
||||
if v2_time > 0:
|
||||
speedup = v1_time / v2_time
|
||||
print(f"\n性能提升: V2 版本比 V1 版本快 {speedup:.2f} 倍")
|
||||
Reference in New Issue
Block a user