Source code for pyqstrat.strategy_components

# $$_ Lines starting with # $$_* autogenerated by jup_mini. Do not modify these
# $$_markdown
# # Strategy Components
# ## Purpose
# Helper components for building strategies with common use cases such as VWAP entry and exit, and finite-risk entries (order sizes computed so that a stop loss caps the loss per trade)
# $$_end_markdown
# $$_code
# $$_ %%checkall
import numpy as np
from dataclasses import dataclass
import math
from types import SimpleNamespace
from typing import Sequence, Callable
from pyqstrat.account import Account
from pyqstrat.pq_types import Contract, ContractGroup, Trade, Order, VWAPOrder
from pyqstrat.pq_types import MarketOrder, LimitOrder, TimeInForce
from pyqstrat.strategy import PriceFunctionType, StrategyContextType
from pyqstrat.pq_utils import assert_, get_child_logger, np_indexof_sorted


_logger = get_child_logger(__name__)


@dataclass
class VectorIndicator:
    '''
    An indicator whose values are precomputed and supplied as a vector

    Args:
        vector: Vector with indicator values. Must be the same length as strategy timestamps
    '''
    vector: np.ndarray

    def __call__(self, contract_group: ContractGroup, timestamps: np.ndarray, indicator_values: SimpleNamespace,
                 context: StrategyContextType) -> np.ndarray:
        # The same precomputed vector is returned regardless of the arguments
        precomputed = self.vector
        return precomputed
@dataclass
class VectorSignal:
    '''
    A signal whose boolean values are precomputed and supplied as a vector

    Args:
        vector: Vector with boolean signal values. Must be the same length as strategy timestamps
    '''
    vector: np.ndarray

    def __call__(self, contract_group: ContractGroup, timestamps: np.ndarray, indicator_values: SimpleNamespace,
                 parent_values: SimpleNamespace, context: StrategyContextType) -> np.ndarray:
        # The same precomputed vector is returned regardless of the arguments
        precomputed = self.vector
        return precomputed
def get_contract_price_from_dict(price_dict: dict[str, dict[np.datetime64, float]],
                                 contract: Contract,
                                 timestamp: np.datetime64) -> float:
    '''
    Look up the price of a single contract at an exact timestamp

    Args:
        price_dict: symbol -> (timestamp -> price)
        contract: the contract whose symbol is looked up
        timestamp: the timestamp to look up
    Returns:
        The price, or nan if the symbol has no price for this exact timestamp
    '''
    symbol = contract.symbol
    assert_(symbol in price_dict, f'{symbol} not found in price_dict')
    found = price_dict[symbol].get(timestamp)
    return math.nan if found is None else found
def get_contract_price_from_array_dict(price_dict: dict[str, tuple[np.ndarray, np.ndarray]],
                                       contract: Contract,
                                       timestamp: np.datetime64,
                                       allow_previous: bool) -> float:
    '''
    Look up the price of a single contract from sorted timestamp / price arrays

    Args:
        price_dict: symbol -> (sorted timestamps, prices), both arrays of the same length
        contract: the contract whose symbol is looked up
        timestamp: the timestamp to look up
        allow_previous: if set, use the price at the latest timestamp <= timestamp instead of requiring an exact match
    Returns:
        The price, or nan if no matching (or earlier, when allow_previous is set) timestamp exists
    '''
    tup: tuple[np.ndarray, np.ndarray] | None = price_dict.get(contract.symbol)
    assert_(tup is not None, f'{contract.symbol} not found in price_dict')
    _timestamps, _prices = tup  # type: ignore
    # a length mismatch would otherwise surface as a confusing IndexError or a misaligned price
    assert_(len(_timestamps) == len(_prices),
            f'{contract.symbol}: timestamps and prices have different lengths: {len(_timestamps)} {len(_prices)}')
    idx: int
    if allow_previous:
        # index of the last timestamp <= the requested one; -1 when every timestamp is later
        idx = int(np.searchsorted(_timestamps, timestamp, side='right')) - 1
    else:
        idx = np_indexof_sorted(_timestamps, timestamp)
    if idx == -1:
        return math.nan
    return _prices[idx]  # type: ignore
@dataclass
class PriceFuncArrays:
    '''
    A function object with a signature of PriceFunctionType. Takes three ndarrays of symbols, timestamps and prices

    Args:
        symbols: symbol for each row
        timestamps: timestamp for each row. Rows do not need to be sorted; they are sorted per symbol here
        prices: price for each row
        allow_previous: if set and we don't find an exact match for the timestamp, use the previous timestamp
    '''
    price_dict: dict[str, tuple[np.ndarray, np.ndarray]]
    allow_previous: bool

    def __init__(self, symbols: np.ndarray, timestamps: np.ndarray, prices: np.ndarray, allow_previous: bool = False) -> None:
        assert_(len(timestamps) == len(symbols) and len(prices) == len(symbols),
                f'arrays have different sizes: {len(timestamps)} {len(symbols)} {len(prices)}')
        price_dict: dict[str, tuple[np.ndarray, np.ndarray]] = {}
        for symbol in np.unique(symbols):
            mask = (symbols == symbol)
            sym_timestamps = timestamps[mask]
            sym_prices = prices[mask]
            # Lookups use binary search (searchsorted / np_indexof_sorted), so timestamps must be sorted.
            # A stable sort leaves already-sorted input unchanged.
            sort_idx = np.argsort(sym_timestamps, kind='stable')
            price_dict[symbol] = (sym_timestamps[sort_idx], sym_prices[sort_idx])
        self.price_dict = price_dict
        self.allow_previous = allow_previous

    def __call__(self, contract: Contract, timestamps: np.ndarray, i: int, context: StrategyContextType) -> float:
        '''Return the price of contract at timestamps[i]; a basket's price is the ratio-weighted sum of its components'''
        price: float = 0.
        timestamp = timestamps[i]
        if contract.is_basket():
            for _contract, ratio in contract.components:
                price += get_contract_price_from_array_dict(self.price_dict, _contract, timestamp, self.allow_previous) * ratio
        else:
            price = get_contract_price_from_array_dict(self.price_dict, contract, timestamp, self.allow_previous)
        return price
@dataclass
class PriceFuncArrayDict:
    '''
    A function object with a signature of PriceFunctionType and takes a dictionary of
        contract name -> tuple of sorted timestamps and prices

    Args:
        price_dict: a dict with key=contract nane and value a tuple of timestamp and price arrays
        allow_previous: if set and we don't find an exact match for the timestamp, use the
            previous timestamp. Useful if you have a dict with keys containing dates instead
            of timestamps

    >>> timestamps = np.arange(np.datetime64('2023-01-01'), np.datetime64('2023-01-04'))
    >>> price_dict = {'AAPL': (timestamps, [8, 9, 10]), 'IBM': (timestamps, [20, 21, 22])}
    >>> pricefunc = PriceFuncArrayDict(price_dict)
    >>> Contract.clear_cache()
    >>> aapl = Contract.create('AAPL')
    >>> assert(pricefunc(aapl, timestamps, 2, None) == 10)
    >>> ibm = Contract.create('IBM')
    >>> basket = Contract.create('AAPL_IBM', components=[(aapl, 1), (ibm, -1)])
    >>> assert(pricefunc(basket, timestamps, 1, None) == -12)
    '''
    price_dict: dict[str, tuple[np.ndarray, np.ndarray]]
    allow_previous: bool

    def __init__(self, price_dict: dict[str, tuple[np.ndarray, np.ndarray]], allow_previous: bool = False) -> None:
        self.price_dict = price_dict
        self.allow_previous = allow_previous

    def __call__(self, contract: Contract, timestamps: np.ndarray, i: int, context: StrategyContextType) -> float:
        when = timestamps[i]
        if not contract.is_basket():
            return get_contract_price_from_array_dict(self.price_dict, contract, when, self.allow_previous)
        # a basket is priced as the ratio-weighted sum of its component prices
        total: float = 0.
        for component, ratio in contract.components:
            total += get_contract_price_from_array_dict(self.price_dict, component, when, self.allow_previous) * ratio
        return total
@dataclass
class PriceFuncDict:
    '''
    A function object with a signature of PriceFunctionType and takes a dictionary of
        contract name -> timestamp -> price

    >>> timestamps = np.arange(np.datetime64('2023-01-01'), np.datetime64('2023-01-04'))
    >>> aapl_prices = [8, 9, 10]
    >>> ibm_prices = [20, 21, 22]
    >>> price_dict = {'AAPL': {}, 'IBM': {}}
    >>> for i, timestamp in enumerate(timestamps):
    ...     price_dict['AAPL'][timestamp] = aapl_prices[i]
    ...     price_dict['IBM'][timestamp] = ibm_prices[i]
    >>> pricefunc = PriceFuncDict(price_dict)
    >>> Contract.clear_cache()
    >>> aapl = Contract.create('AAPL')
    >>> assert(pricefunc(aapl, timestamps, 2, None) == 10)
    >>> ibm = Contract.create('IBM')
    >>> basket = Contract.create('AAPL_IBM', components=[(aapl, 1), (ibm, -1)])
    >>> assert(pricefunc(basket, timestamps, 1, None) == -12)
    '''
    price_dict: dict[str, dict[np.datetime64, float]]

    def __init__(self, price_dict: dict[str, dict[np.datetime64, float]]) -> None:
        self.price_dict = price_dict

    def __call__(self, contract: Contract, timestamps: np.ndarray, i: int, context: StrategyContextType) -> float:
        when = timestamps[i]
        if not contract.is_basket():
            return get_contract_price_from_dict(self.price_dict, contract, when)
        # a basket is priced as the ratio-weighted sum of its component prices
        total: float = 0.
        for component, ratio in contract.components:
            total += get_contract_price_from_dict(self.price_dict, component, when) * ratio
        return total
@dataclass
class SimpleMarketSimulator:
    '''
    A function object with a signature of MarketSimulatorType. It can take into account slippage and commission

    >>> ContractGroup.clear_cache()
    >>> Contract.clear_cache()
    >>> put_symbol, call_symbol = 'SPX-P-3500-2023-01-19', 'SPX-C-4000-2023-01-19'
    >>> put_contract = Contract.create(put_symbol)
    >>> call_contract = Contract.create(call_symbol)
    >>> basket = Contract.create('test_contract')
    >>> basket.components = [(put_contract, -1), (call_contract, 1)]
    >>> timestamp = np.datetime64('2023-01-03 14:35')
    >>> price_func = PriceFuncDict({put_symbol: {timestamp: 4.8}, call_symbol: {timestamp: 3.5}})
    >>> order = MarketOrder(contract=basket, timestamp=timestamp, qty=10, reason_code='TEST')
    >>> sim = SimpleMarketSimulator(price_func=price_func, slippage_pct=0)
    >>> out = sim([order], 0, np.array([timestamp]), {}, {}, SimpleNamespace())
    >>> assert(len(out) == 1)
    >>> assert(math.isclose(out[0].price, -1.3))
    >>> assert(out[0].qty == 10)
    '''
    price_func: PriceFunctionType
    slippage_pct: float
    commission: float
    price_rounding: int
    post_trade_func: Callable[[Trade, StrategyContextType], None] | None

    def __init__(self,
                 price_func: PriceFunctionType,
                 slippage_pct: float = 0.,
                 commission: float = 0,
                 price_rounding: int = 3,
                 post_trade_func: Callable[[Trade, StrategyContextType], None] | None = None) -> None:
        '''
        Args:
            price_func: A function that we use to get the price to execute at
            slippage_pct: Slippage per dollar transacted.
                Meant to simulate the difference between bid/ask mid and execution price
            commission: Fee paid to broker per trade
            price_rounding: Number of decimal places the execution price is rounded to
            post_trade_func: If not None, called with each trade and the strategy context after the trade is created
        '''
        self.price_func = price_func
        self.slippage_pct = slippage_pct
        self.commission = commission
        self.price_rounding = price_rounding
        self.post_trade_func = post_trade_func

    def __call__(self,
                 orders: Sequence[Order],
                 i: int,
                 timestamps: np.ndarray,
                 indicators: dict[str, SimpleNamespace],
                 signals: dict[str, SimpleNamespace],
                 strategy_context: SimpleNamespace) -> list[Trade]:
        '''
        Fill market and limit orders at the price returned by price_func, adjusted for slippage.

        Orders of any other type are ignored, as are orders whose price is nan. A limit order is only
        filled if the execution price is at or better than its limit (<= limit for buys, >= limit for sells).
        TODO: code for stop orders
        '''
        trades = []
        timestamp = timestamps[i]
        for order in orders:
            if not isinstance(order, MarketOrder) and not isinstance(order, LimitOrder):
                continue
            contract = order.contract
            if not contract.is_basket():
                raw_price = self.price_func(contract, timestamps, i, strategy_context)
            else:
                raw_price = 0.
                for (_contract, ratio) in contract.components:
                    raw_price += self.price_func(_contract, timestamps, i, strategy_context) * ratio
                    if np.isnan(raw_price):
                        break
            if np.isnan(raw_price):
                continue
            # Slippage must always worsen the fill (buys pay more, sells receive less), including for
            # baskets with a negative price, so scale by the magnitude of the price.
            slippage = self.slippage_pct * abs(raw_price)
            if order.qty < 0:
                slippage = -slippage
            price = round(raw_price + slippage, self.price_rounding)

            if isinstance(order, LimitOrder) and np.isfinite(order.limit_price):
                # a buy limit cannot fill above its limit, a sell limit cannot fill below it
                if ((order.qty > 0 and price > order.limit_price)
                        or (order.qty < 0 and price < order.limit_price)):
                    continue

            trade = Trade(order.contract, order, timestamp, order.qty, price, self.commission)
            order.fill()
            trades.append(trade)
            if self.post_trade_func is not None:
                self.post_trade_func(trade, strategy_context)
        return trades
@dataclass
class PercentOfEquityTradingRule:
    '''
    A rule that trades a percentage of equity.

    Args:
        reason_code: Reason for entering the order, used for display
        equity_percent: Percentage of equity used to size order
        long: Whether we want to go long or short
        allocate_risk: If set, we divide the max risk by number of trades. Otherwise each trade will be allocated max risk
        limit_increment: If not nan, we add or subtract this number from current market price (if selling or buying respectively)
            and create a limit order. If nan, we create market orders
        price_func: The function we use to get intraday prices

    Returns an empty list (no orders for any contract) if any contract has a nan price or rounds to a zero qty.
    '''
    reason_code: str
    price_func: PriceFunctionType
    equity_percent: float = 0.1  # use 10% of equity by default
    long: bool = True
    allocate_risk: bool = False
    limit_increment: float = math.nan

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 strategy_context: StrategyContextType) -> list[Order]:
        timestamp = timestamps[i]
        contracts = contract_group.get_contracts()
        orders: list[Order] = []
        for contract in contracts:
            entry_price_est = self.price_func(contract, timestamps, i, strategy_context)
            if math.isnan(entry_price_est):
                return []
            curr_equity = account.equity(timestamp)
            risk_amount = self.equity_percent * curr_equity
            order_qty = risk_amount / entry_price_est
            if self.allocate_risk:
                order_qty /= len(contracts)  # divide up qty equally between all contracts
            if not self.long:
                order_qty *= -1
            # round towards zero so we never exceed the risk amount
            order_qty = math.floor(order_qty) if order_qty > 0 else math.ceil(order_qty)
            if math.isclose(order_qty, 0.):
                return []
            if math.isfinite(self.limit_increment):
                # buys are placed limit_increment below the estimated price, sells limit_increment above it
                limit_price = entry_price_est - np.sign(order_qty) * self.limit_increment
                limit_order = LimitOrder(contract=contract,
                                         timestamp=timestamp,
                                         qty=order_qty,
                                         limit_price=limit_price,
                                         reason_code=self.reason_code)
                orders.append(limit_order)
            else:
                market_order = MarketOrder(contract=contract,
                                           timestamp=timestamp,
                                           qty=order_qty,
                                           reason_code=self.reason_code)
                orders.append(market_order)
        return orders
@dataclass
class VWAPEntryRule:
    '''
    A rule that generates VWAP orders

    Args:
        reason_code: Reason for each order. For display purposes
        vwap_minutes: How long the vwap period is. For example, a 5 minute vwap order will execute at 5 minute vwap
            from when it is sent to the market
        price_func: A function that this rule uses to get market price at a given timestamp
        long: Whether we want to go long or short
        percent_of_equity: Fraction of current equity to risk, split equally between the contracts in the group.
            With a stop, order qty is calculated so that if the stop price is reached, we lose this amount.
            Without a stop, order qty is this amount divided by the estimated entry price
        stop_price_ind: Name of an indicator holding the stop price. For long orders we don't enter if
            estimated entry price - stop price < min_price_diff_pct * estimated entry price (the opposite for
            short orders), or if the stop price is nan. If None, orders are sent without a vwap stop
        min_price_diff_pct: See stop_price_ind
        single_entry_per_day: If set, don't enter if there were already trades for this contract group today
    '''
    reason_code: str
    vwap_minutes: int
    price_func: PriceFunctionType
    long: bool
    percent_of_equity: float
    stop_price_ind: str | None
    min_price_diff_pct: float
    single_entry_per_day: bool

    def __init__(self,
                 reason_code: str,
                 vwap_minutes: int,
                 price_func: PriceFunctionType,
                 long: bool = True,
                 percent_of_equity: float = 0.1,
                 stop_price_ind: str | None = None,
                 min_price_diff_pct: float = 0,
                 single_entry_per_day: bool = False) -> None:
        self.reason_code = reason_code
        self.price_func = price_func
        self.long = long
        self.vwap_minutes = vwap_minutes
        self.percent_of_equity = percent_of_equity
        self.stop_price_ind = stop_price_ind
        self.min_price_diff_pct = min_price_diff_pct
        self.single_entry_per_day = single_entry_per_day

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 strategy_context: StrategyContextType) -> list[Order]:
        timestamp = timestamps[i]
        if self.single_entry_per_day:
            date = timestamp.astype('M8[D]')
            trades = account.get_trades_for_date(contract_group.name, date)
            if len(trades):
                return []
        # don't send new orders while an order for this contract group is still working
        for open_order in current_orders:
            if open_order.contract.contract_group == contract_group and open_order.is_open():
                return []
        orders: list[Order] = []
        contracts = contract_group.get_contracts()
        for contract in contracts:
            entry_price_est = self.price_func(contract, timestamps, i, strategy_context)
            if math.isnan(entry_price_est):
                return []
            if self.stop_price_ind:
                _stop_price_ind = getattr(indicator_values, self.stop_price_ind)
                stop_price = _stop_price_ind[i]
                # without a valid stop we cannot size the order by risk
                if math.isnan(stop_price):
                    return []
                if self.long and (entry_price_est - stop_price) < self.min_price_diff_pct * entry_price_est:
                    return []
                if not self.long and (stop_price - entry_price_est) < self.min_price_diff_pct * entry_price_est:
                    return []
            else:
                stop_price = math.nan
            curr_equity = account.equity(timestamp)
            risk_amount = self.percent_of_equity * curr_equity
            if math.isnan(stop_price):
                # No stop: size by notional. Previously this produced a nan qty and math.floor/ceil raised.
                order_qty = risk_amount / entry_price_est
                if not self.long:
                    order_qty = -order_qty
            else:
                price_diff = entry_price_est - stop_price
                if price_diff == 0:
                    return []
                # negative for shorts since their stop is above the entry price
                order_qty = risk_amount / price_diff
            order_qty /= len(contracts)  # divide up equity percentage equally
            order_qty = math.floor(order_qty) if order_qty > 0 else math.ceil(order_qty)
            if math.isclose(order_qty, 0.):
                return []
            vwap_end_time = timestamp + np.timedelta64(self.vwap_minutes, 'm')
            order = VWAPOrder(contract=contract,
                              timestamp=timestamp,
                              vwap_stop=stop_price,
                              vwap_end_time=vwap_end_time,
                              qty=order_qty,
                              time_in_force=TimeInForce.GTC,
                              reason_code=self.reason_code)
            orders.append(order)
        # return every order, not just the last one created
        return orders
@dataclass
class VWAPCloseRule:
    '''
    Rule to close out all positions in a contract group using VWAP orders

    Args:
        reason_code: Reason for each order. For display purposes
        vwap_minutes: How long the vwap period is. For example, a 5 minute vwap order will execute at 5 minute vwap
    '''
    reason_code: str
    vwap_minutes: int

    def __init__(self, vwap_minutes: int, reason_code: str) -> None:
        self.vwap_minutes = vwap_minutes
        self.reason_code = reason_code

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 strategy_context: StrategyContextType) -> list[Order]:
        now = timestamps[i]
        # wait for any working order in this contract group to finish before sending more
        if any(o.contract.contract_group == contract_group and o.is_open() for o in current_orders):
            return []
        end_time = now + np.timedelta64(self.vwap_minutes, 'm')
        return [VWAPOrder(contract=pos_contract,  # type: ignore
                          timestamp=now,
                          vwap_end_time=end_time,
                          qty=-pos_qty,
                          time_in_force=TimeInForce.GTC,
                          reason_code=self.reason_code)
                for pos_contract, pos_qty in account.positions(contract_group, now)]
@dataclass
class VWAPMarketSimulator:
    '''
    A market simulator that simulates buying and selling at VWAP prices.
    This works with VWAP orders and ignores all other order types.

    The order executes when any of these happens:
        a. The vwap end time defined in the VWAP order is reached
        b. Market price <= vwap stop price defined in the VWAP order for buy orders
        c. Market price >= vwap stop price for sell orders
        d. We reach the last timestamp, or the last timestamp of the current day

    If the order ends because of its vwap stop (b or c), only the fraction of the qty corresponding to the elapsed
    fraction of the vwap period is filled. The rest of the order is cancelled.
    If no valid price / volume data is found for the vwap period, buy orders use backup_price_indicator and
    sell orders are left unfilled.
    '''
    price_indicator: str
    volume_indicator: str
    backup_price_indicator: str | None

    def __init__(self,
                 price_indicator: str,
                 volume_indicator: str,
                 backup_price_indicator: str | None = None) -> None:
        '''
        Args:
            price_indicator: An indicator that contains historical trade price per timestamp
            volume_indicator: An indicator that contains volume per timestamp
            backup_price_indicator: Execution price to use if price or volume is missing
        '''
        self.price_indicator = price_indicator
        self.volume_indicator = volume_indicator
        self.backup_price_indicator = backup_price_indicator

    def __call__(self,
                 orders: Sequence[Order],
                 i: int,
                 timestamps: np.ndarray,
                 indicators: dict[ContractGroup, SimpleNamespace],
                 signals: dict[ContractGroup, SimpleNamespace],
                 strategy_context: SimpleNamespace) -> list[Trade]:
        trades = []
        timestamp = timestamps[i]
        for order in orders:
            if not isinstance(order, VWAPOrder):
                continue
            cg = order.contract.contract_group
            inds = indicators.get(cg)
            assert_(inds is not None, f'indicators not found for contract group: {cg} {timestamp} {i}')
            # use a None default so a missing indicator reaches the assert_ with a useful message
            # instead of raising a bare AttributeError
            price_ind = getattr(inds, self.price_indicator, None)
            assert_(price_ind is not None, f'indicator: {self.price_indicator} not found for contract group: {cg} {timestamp} {i}')
            volume_ind = getattr(inds, self.volume_indicator, None)
            assert_(volume_ind is not None, f'indicator: {self.volume_indicator} not found for contract group: {cg} {timestamp} {i}')

            end_order = False
            if math.isfinite(order.vwap_stop) and (
                    (order.qty >= 0 and price_ind[i] <= order.vwap_stop)
                    or (order.qty < 0 and price_ind[i] >= order.vwap_stop)):
                end_order = True

            # keep waiting unless the vwap period is over or this is the last bar overall / of the day
            if not end_order and timestamp < order.vwap_end_time and i != len(timestamps) - 1 \
                    and not timestamps[i + 1].astype('M8[D]') > timestamps[i].astype('M8[D]'):
                continue

            mask = (price_ind > 0) & (volume_ind > 0)
            if end_order:
                mask &= (timestamps >= order.timestamp) & (timestamps <= timestamp)
            else:
                mask &= (timestamps >= order.timestamp) & (timestamps <= order.vwap_end_time)
            amt = price_ind[mask] * volume_ind[mask]
            if not len(amt):
                if order.qty <= 0:
                    continue
                _logger.info(f'using backup price for {cg} {timestamp} {i} qty: {order.qty} {order}')
                assert_(self.backup_price_indicator is not None,
                        f'backup price indicator not found and no vwap found for: {cg} {timestamp} {i}')
                _backup_price_ind = getattr(inds, self.backup_price_indicator, None)  # type: ignore
                assert_(_backup_price_ind is not None, f'backup price indicator not found for: {cg} {timestamp} {i}')
                vwap = _backup_price_ind[i]
            else:
                vwap = np.sum(amt) / np.sum(volume_ind[mask])
            assert_(vwap >= 0)

            fill_qty = order.qty
            if end_order:
                # fill in proportion to how much of the vwap period has elapsed
                fill_fraction = (timestamp - order.timestamp) / (order.vwap_end_time - order.timestamp)
                fill_fraction = min(fill_fraction, 1)
                fill_qty = np.fix(order.qty * fill_fraction)
            order.fill(fill_qty)
            order.cancel()
            trade = Trade(order.contract, order, timestamp, fill_qty, vwap)
            trades.append(trade)
        return trades
# Signature of a contract filter: it receives the same arguments as a rule
# and returns the names of the contracts to trade at this timestamp
ContractFilterType = Callable[
    [ContractGroup,
     int,
     np.ndarray,
     SimpleNamespace,
     np.ndarray,
     Account,
     Sequence[Order],
     StrategyContextType],
    list[str]]
@dataclass
class BracketOrderEntryRule:
    '''
    A rule that generates orders with stops

    Args:
        reason_code: Reason for the orders created used for display
        price_func: A function that returns price given a contract and timestamp
        long: whether we want to go long or short
        percent_of_equity: How much to risk per trade as a percentage of current equity.
            Used to calculate order qty so that if we get stopped out, we don't lose
            more than this amount. Of course if price gaps up or down rather than moving smoothly,
            we may lose more.
        stop_return_func: A function that gives us the distance between entry price and stop price.
            Its values must be negative
        min_stop_return: We will not enter if the stop_return is closer than this percent to the stop.
            Otherwise, we get very large trade sizes
        max_position_size: An order should not result in a position that is greater than this percent
            of the portfolio
        single_entry_per_day: If set, we don't enter if there were already trades for the contract group today
        contract_filter: A function that takes similar arguments as a rule (with ContractGroup) replaced by
            Contract but returns a list of contract names for each positive signal timestamp. For example,
            for a strategy that trades 5000 stocks, you may want to construct a single signal and apply it
            to different contracts at different times, rather than create 5000 signals that will call your rule
            5000 times every time the signal is true.

    >>> timestamps = np.arange(np.datetime64('2023-01-01'), np.datetime64('2023-01-05'))
    >>> sig_values = np.full(len(timestamps), False)
    >>> aapl_prices = np.array([100.1, 100.2, 100.3, 100.4])
    >>> ibm_prices = np.array([200.1, 200.2, 200.3, 200.4])
    >>> aapl_stops = np.array([-0.5, -0.3, -0.2, -0.01])
    >>> ibm_stops= np.array([-0.5, -0.3, -0.2, -0.15])
    >>> price_dict = {'AAPL': (timestamps, aapl_prices), 'IBM': (timestamps, ibm_prices)}
    >>> stop_dict = {'AAPL': (timestamps, aapl_stops), 'IBM': (timestamps, ibm_stops)}
    >>> price_func = PriceFuncArrayDict(price_dict)
    >>> fr = BracketOrderEntryRule('TEST_ENTRY', price_func, long=False)
    >>> default_cg = ContractGroup.get('DEFAULT')
    >>> default_cg.clear_cache()
    >>> default_cg.add_contract(Contract.get_or_create('AAPL'))
    >>> default_cg.add_contract(Contract.get_or_create('IBM'))
    >>> account = SimpleNamespace()
    >>> account.equity = lambda x: 1e6
    >>> orders = fr(default_cg, 1, timestamps, SimpleNamespace(), sig_values, account, [], SimpleNamespace())
    >>> assert len(orders) == 2 and orders[0].qty == -998 and orders[1].qty == -499
    >>> stop_return_func = PriceFuncArrayDict(stop_dict)
    >>> fr = BracketOrderEntryRule('TEST_ENTRY', price_func, long=True, stop_return_func=stop_return_func, min_stop_return=-0.1)
    >>> orders = fr(default_cg, 2, timestamps, SimpleNamespace(), sig_values, account, [], SimpleNamespace())
    >>> assert len(orders) == 2 and orders[0].qty == 4985 and orders[1].qty == 2496
    >>> orders = fr(default_cg, 3, timestamps, SimpleNamespace(), sig_values, account, [], SimpleNamespace())
    >>> assert len(orders) == 1 and orders[0].qty == 3326
    '''
    reason_code: str
    price_func: PriceFunctionType
    long: bool
    percent_of_equity: float
    # was misspelled `min_stop_returnt`, so the dataclass-generated __repr__/__eq__ read a non-existent attribute
    min_stop_return: float
    max_position_size: float
    single_entry_per_day: bool
    contract_filter: ContractFilterType | None
    stop_return_func: PriceFunctionType | None

    def __init__(self,
                 reason_code: str,
                 price_func: PriceFunctionType,
                 long: bool = True,
                 percent_of_equity: float = 0.1,
                 min_stop_return: float = 0,
                 max_position_size: float = 0,
                 single_entry_per_day: bool = False,
                 contract_filter: ContractFilterType | None = None,
                 stop_return_func: PriceFunctionType | None = None) -> None:
        self.reason_code = reason_code
        self.price_func = price_func
        self.long = long
        self.percent_of_equity = percent_of_equity
        self.stop_return_func = stop_return_func
        self.min_stop_return = min_stop_return
        self.max_position_size = max_position_size
        self.single_entry_per_day = single_entry_per_day
        self.contract_filter = contract_filter

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 strategy_context: StrategyContextType) -> list[Order]:
        timestamp = timestamps[i]
        date = timestamp.astype('M8[D]')
        contracts: list[Contract] = []
        if self.contract_filter is not None:
            names = self.contract_filter(
                contract_group, i, timestamps, indicator_values, signal_values, account, current_orders, strategy_context)
            for name in names:
                _contract = Contract.get(name)
                if _contract is None:
                    continue
                contracts.append(_contract)
        else:
            contracts = contract_group.get_contracts()

        # the single-entry check only depends on the contract group, so do it once rather than per contract
        if contracts and self.single_entry_per_day:
            trades = account.get_trades_for_date(contract_group.name, date)
            if len(trades):
                return []

        orders: list[Order] = []
        for contract in contracts:
            entry_price_est = self.price_func(contract, timestamps, i, strategy_context)  # type: ignore
            if math.isnan(entry_price_est):
                continue
            # without a stop_return_func, a stop price of 0 means we size by the full notional
            stop_price = 0.
            if self.stop_return_func is not None:
                stop_return = self.stop_return_func(contract, timestamps, i, strategy_context)  # type: ignore
                assert_(stop_return < 0, f'stop_return must be negative: {stop_return} {timestamp} {contract.symbol}')
                if stop_return > self.min_stop_return:
                    # stop_price has not been computed yet here, so report the stop return instead
                    _logger.info(f'entry price estimate: {entry_price_est} too close to stop: stop return: {stop_return}'
                                 f' min stop return: {self.min_stop_return}')
                    continue
                if not self.long:
                    stop_return = -stop_return  # shorts are stopped out above the entry price
                stop_price = entry_price_est * (1 + stop_return)
                if ((self.long and entry_price_est <= stop_price) or
                        (not self.long and entry_price_est >= stop_price)):
                    _logger.info(f'entry price estimate: {entry_price_est} exceeds stop price: {stop_price}')
                    continue

            curr_equity = account.equity(timestamp)
            risk_amount = self.percent_of_equity * curr_equity
            order_qty = risk_amount / abs(entry_price_est - stop_price)
            if self.max_position_size > 0:
                max_qty = abs(self.max_position_size * curr_equity / entry_price_est)
                orig_qty = order_qty
                if max_qty < abs(orig_qty):
                    order_qty = max_qty * np.sign(orig_qty)
                    _logger.info(f'max postion size exceeded: {contract} timestamp: {timestamp} max_qty: {max_qty}'
                                 f' orig_qty: {orig_qty} new qty: {order_qty} max_qty: {max_qty} pos size: {self.max_position_size}'
                                 f' curr_equity: {curr_equity} entry_price_est: {entry_price_est}')
            # order_qty is positive here; the sign is applied after rounding down
            order_qty = math.floor(order_qty)
            if not self.long:
                order_qty = -order_qty
            if math.isclose(order_qty, 0.):
                continue
            order = MarketOrder(contract=contract,  # type: ignore
                                timestamp=timestamp,
                                qty=order_qty,
                                reason_code=self.reason_code)
            orders.append(order)
        return orders
@dataclass
class ClosePositionExitRule:
    '''
    A rule that closes out every open position in the contract group

    Args:
        reason_code: the reason for closing out used for display purposes
        price_func: the function this rule uses to get market prices
        limit_increment: if not nan, we add or subtract this number from current market price
            (if selling or buying respectively) and create a limit order. Otherwise market orders are created.
            Must be non-negative
    '''
    reason_code: str
    price_func: PriceFunctionType
    limit_increment: float

    def __init__(self,
                 reason_code: str,
                 price_func: PriceFunctionType,
                 limit_increment: float = math.nan) -> None:
        self.reason_code = reason_code
        self.price_func = price_func
        assert_(math.isnan(limit_increment) or limit_increment >= 0,
                f'limit_increment: {limit_increment} cannot be negative')
        self.limit_increment = limit_increment

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 strategy_context: StrategyContextType) -> list[Order]:
        now = timestamps[i]
        use_limit = math.isfinite(self.limit_increment)
        exit_orders: list[Order] = []
        for pos_contract, pos_qty in account.positions(contract_group, now):
            if use_limit:
                # long positions sell above the market price, short positions buy below it
                offset = self.limit_increment if pos_qty >= 0 else -self.limit_increment
                limit_price = self.price_func(pos_contract, timestamps, i, strategy_context) + offset
                exit_orders.append(LimitOrder(contract=pos_contract,
                                              timestamp=now,
                                              qty=-pos_qty,
                                              limit_price=limit_price,
                                              reason_code=self.reason_code))
            else:
                exit_orders.append(MarketOrder(contract=pos_contract,
                                               timestamp=now,
                                               qty=-pos_qty,
                                               reason_code=self.reason_code))
        return exit_orders
@dataclass
class StopReturnExitRule:
    '''
    A rule that exits any given positions if a stop is hit.

    The stop price is entry price * (1 + stop return), with the sign of the stop return flipped for short
    positions. context.entry_prices must map each date to a dict of symbol -> entry price, for example set by
    the market simulator when a position is entered. stop_return_func must return negative values.
    '''
    reason_code: str
    price_func: PriceFunctionType
    stop_return_func: PriceFunctionType

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 context: StrategyContextType) -> list[Order]:
        now = timestamps[i]
        day = now.astype('M8[D]')
        entry_prices: dict[str, float] = context.entry_prices[day]
        assert_(len(entry_prices) > 0, f'no symbols entered for: {day}')
        exit_orders: list[Order] = []
        for pos_contract, pos_qty in account.positions(contract_group, now):
            symbol = pos_contract.symbol
            stop_ret = self.stop_return_func(pos_contract, timestamps, i, context)
            assert_(stop_ret < 0, f'stop_return must be negative: {stop_ret} {now} {symbol}')
            # short positions are stopped out when price rises above entry
            signed_ret = -stop_ret if pos_qty < 0 else stop_ret
            stop_price = entry_prices[symbol] * (1 + signed_ret)
            curr_price = self.price_func(pos_contract, timestamps, i, context)
            if math.isnan(curr_price):
                continue
            if pos_qty > 0:
                still_safe = curr_price > stop_price
            else:
                still_safe = curr_price < stop_price
            if still_safe:
                continue
            exit_orders.append(MarketOrder(contract=pos_contract, timestamp=now, qty=-pos_qty, reason_code=self.reason_code))
        return exit_orders
if __name__ == "__main__":
    import doctest
    # tolerate whitespace differences and allow '...' wildcards in the doctests of this module
    doctest_flags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS
    doctest.testmod(optionflags=doctest_flags)
# $$_end_code