# $$_ Lines starting with # $$_* autogenerated by jup_mini. Do not modify these
# $$_markdown
# # Strategy Components
# ## Purpose
# Helper components for building strategies with common use cases, such as VWAP entry and exit, and entries with limited risk (order sizing based on a stop price)
# $$_end_markdown
# $$_code
# $$_ %%checkall
import numpy as np
from dataclasses import dataclass
import math
from types import SimpleNamespace
from typing import Sequence, Callable
from pyqstrat.account import Account
from pyqstrat.pq_types import Contract, ContractGroup, Trade, Order, VWAPOrder
from pyqstrat.pq_types import MarketOrder, LimitOrder, TimeInForce
from pyqstrat.strategy import PriceFunctionType, StrategyContextType
from pyqstrat.pq_utils import assert_, get_child_logger, np_indexof_sorted
_logger = get_child_logger(__name__)
[docs]
@dataclass
class VectorIndicator:
    '''
    An indicator whose values were computed ahead of time.

    Args:
        vector: Indicator values, one per strategy timestamp. Must be the same length as strategy timestamps
    '''
    vector: np.ndarray

    def __call__(self,
                 contract_group: ContractGroup,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 context: StrategyContextType) -> np.ndarray:
        '''Ignore every argument and hand back the stored array itself (no copy is made).'''
        stored: np.ndarray = self.vector
        return stored
[docs]
@dataclass
class VectorSignal:
    '''
    A signal whose (boolean) values were computed ahead of time.

    Args:
        vector: Signal values, one per strategy timestamp. Must be the same length as strategy timestamps
    '''
    vector: np.ndarray

    def __call__(self,
                 contract_group: ContractGroup,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 parent_values: SimpleNamespace,
                 context: StrategyContextType) -> np.ndarray:
        '''Ignore every argument and hand back the stored array itself (no copy is made).'''
        stored: np.ndarray = self.vector
        return stored
[docs]
def get_contract_price_from_dict(price_dict: dict[str, dict[np.datetime64, float]],
                                 contract: Contract,
                                 timestamp: np.datetime64) -> float:
    '''
    Look up the price of a single contract at an exact timestamp.

    Args:
        price_dict: symbol -> (timestamp -> price)
        contract: contract whose symbol is looked up; the symbol must be a key of price_dict
        timestamp: the timestamp to look up
    Returns:
        The stored price, or NaN if there is no entry for that timestamp
    '''
    symbol = contract.symbol
    assert_(symbol in price_dict, f'{symbol} not found in price_dict')
    found = price_dict[symbol].get(timestamp)
    return math.nan if found is None else found
[docs]
def get_contract_price_from_array_dict(price_dict: dict[str, tuple[np.ndarray, np.ndarray]],
                                       contract: Contract,
                                       timestamp: np.datetime64,
                                       allow_previous: bool) -> float:
    '''
    Look up a contract's price in a dict of symbol -> (sorted timestamps, prices).

    Args:
        price_dict: maps each symbol to a tuple of a sorted timestamp array and a parallel price array
        contract: the contract whose symbol is looked up; the symbol must be present in price_dict
        timestamp: the timestamp to price at
        allow_previous: if set, use the price at the last timestamp <= timestamp instead of
            requiring an exact match
    Returns:
        The price, or NaN if no usable timestamp was found (lookup index of -1)
    '''
    entry: tuple[np.ndarray, np.ndarray] | None = price_dict.get(contract.symbol)
    assert_(entry is not None, f'{contract.symbol} not found in price_dict')
    sorted_timestamps, prices = entry  # type: ignore
    idx: int
    if not allow_previous:
        # NOTE(review): np_indexof_sorted presumably returns -1 when there is no exact match
        idx = np_indexof_sorted(sorted_timestamps, timestamp)
    else:
        # Index of the last timestamp <= the requested one; -1 when the requested one precedes them all
        idx = int(np.searchsorted(sorted_timestamps, timestamp, side='right')) - 1
    return math.nan if idx == -1 else prices[idx]  # type: ignore
[docs]
@dataclass
class PriceFuncArrays:
    '''
    A function object with a signature of PriceFunctionType, built from three parallel ndarrays
    of symbols, timestamps and prices.

    Args:
        symbols: the symbol of each row
        timestamps: the timestamp of each row
        prices: the price of each row
        allow_previous: if set and we don't find an exact match for the timestamp, use the price at the
            most recent earlier timestamp for that symbol

    Rows do not need to be sorted. Each symbol's rows are stably sorted by timestamp on construction,
    because price lookups use a binary search over the timestamps.
    '''
    price_dict: dict[str, tuple[np.ndarray, np.ndarray]]
    allow_previous: bool

    def __init__(self, symbols: np.ndarray, timestamps: np.ndarray, prices: np.ndarray, allow_previous: bool = False) -> None:
        assert_(len(timestamps) == len(symbols) and len(prices) == len(symbols),
                f'arrays have different sizes: {len(timestamps)} {len(symbols)} {len(prices)}')
        price_dict: dict[str, tuple[np.ndarray, np.ndarray]] = {}
        for symbol in np.unique(symbols):
            mask = (symbols == symbol)
            symbol_timestamps = timestamps[mask]
            symbol_prices = prices[mask]
            # Lookups binary-search these timestamps, so they must be sorted. A stable sort leaves
            # already-sorted input (including duplicate timestamps) in its original order.
            order = np.argsort(symbol_timestamps, kind='stable')
            price_dict[symbol] = (symbol_timestamps[order], symbol_prices[order])
        self.price_dict = price_dict
        self.allow_previous = allow_previous

    def __call__(self, contract: Contract, timestamps: np.ndarray, i: int, context: StrategyContextType) -> float:
        '''Price `contract` at timestamps[i]; a basket is priced as the ratio-weighted sum of its components.'''
        price: float = 0.
        timestamp = timestamps[i]
        if contract.is_basket():
            for _contract, ratio in contract.components:
                price += get_contract_price_from_array_dict(self.price_dict, _contract, timestamp, self.allow_previous) * ratio
        else:
            price = get_contract_price_from_array_dict(self.price_dict, contract, timestamp, self.allow_previous)
        return price
[docs]
@dataclass
class PriceFuncArrayDict:
    '''
    A function object with a signature of PriceFunctionType that takes a dictionary of
    contract name -> tuple of sorted timestamps and prices.

    Args:
        price_dict: a dict with key=contract name and value a tuple of timestamp and price arrays
        allow_previous: if set and we don't find an exact match for the timestamp, use the
            previous timestamp. Useful if you have a dict with keys containing dates instead of timestamps

    >>> timestamps = np.arange(np.datetime64('2023-01-01'), np.datetime64('2023-01-04'))
    >>> price_dict = {'AAPL': (timestamps, [8, 9, 10]), 'IBM': (timestamps, [20, 21, 22])}
    >>> pricefunc = PriceFuncArrayDict(price_dict)
    >>> Contract.clear_cache()
    >>> aapl = Contract.create('AAPL')
    >>> assert(pricefunc(aapl, timestamps, 2, None) == 10)
    >>> ibm = Contract.create('IBM')
    >>> basket = Contract.create('AAPL_IBM', components=[(aapl, 1), (ibm, -1)])
    >>> assert(pricefunc(basket, timestamps, 1, None) == -12)
    '''
    price_dict: dict[str, tuple[np.ndarray, np.ndarray]]
    allow_previous: bool

    def __init__(self, price_dict: dict[str, tuple[np.ndarray, np.ndarray]], allow_previous: bool = False) -> None:
        self.price_dict = price_dict
        self.allow_previous = allow_previous

    def __call__(self, contract: Contract, timestamps: np.ndarray, i: int, context: StrategyContextType) -> float:
        '''Price `contract` at timestamps[i]; a basket is priced as the ratio-weighted sum of its components.'''
        timestamp = timestamps[i]
        if not contract.is_basket():
            return get_contract_price_from_array_dict(self.price_dict, contract, timestamp, self.allow_previous)
        basket_price: float = 0.
        for component, ratio in contract.components:
            component_price = get_contract_price_from_array_dict(
                self.price_dict, component, timestamp, self.allow_previous)
            basket_price += component_price * ratio
        return basket_price
[docs]
@dataclass
class PriceFuncDict:
    '''
    A function object with a signature of PriceFunctionType that takes a dictionary of
    contract name -> timestamp -> price. Missing timestamps price as NaN.

    >>> timestamps = np.arange(np.datetime64('2023-01-01'), np.datetime64('2023-01-04'))
    >>> aapl_prices = [8, 9, 10]
    >>> ibm_prices = [20, 21, 22]
    >>> price_dict = {'AAPL': {}, 'IBM': {}}
    >>> for i, timestamp in enumerate(timestamps):
    ...     price_dict['AAPL'][timestamp] = aapl_prices[i]
    ...     price_dict['IBM'][timestamp] = ibm_prices[i]
    >>> pricefunc = PriceFuncDict(price_dict)
    >>> Contract.clear_cache()
    >>> aapl = Contract.create('AAPL')
    >>> assert(pricefunc(aapl, timestamps, 2, None) == 10)
    >>> ibm = Contract.create('IBM')
    >>> basket = Contract.create('AAPL_IBM', components=[(aapl, 1), (ibm, -1)])
    >>> assert(pricefunc(basket, timestamps, 1, None) == -12)
    '''
    price_dict: dict[str, dict[np.datetime64, float]]

    def __init__(self, price_dict: dict[str, dict[np.datetime64, float]]) -> None:
        self.price_dict = price_dict

    def __call__(self, contract: Contract, timestamps: np.ndarray, i: int, context: StrategyContextType) -> float:
        '''Price `contract` at timestamps[i]; a basket is priced as the ratio-weighted sum of its components.'''
        timestamp = timestamps[i]
        if not contract.is_basket():
            return get_contract_price_from_dict(self.price_dict, contract, timestamp)
        basket_price: float = 0.
        for component, ratio in contract.components:
            basket_price += get_contract_price_from_dict(self.price_dict, component, timestamp) * ratio
        return basket_price
[docs]
@dataclass
class SimpleMarketSimulator:
    '''
    A function object with a signature of MarketSimulatorType.
    It can take into account slippage and commission
    >>> ContractGroup.clear_cache()
    >>> Contract.clear_cache()
    >>> put_symbol, call_symbol = 'SPX-P-3500-2023-01-19', 'SPX-C-4000-2023-01-19'
    >>> put_contract = Contract.create(put_symbol)
    >>> call_contract = Contract.create(call_symbol)
    >>> basket = Contract.create('test_contract')
    >>> basket.components = [(put_contract, -1), (call_contract, 1)]
    >>> timestamp = np.datetime64('2023-01-03 14:35')
    >>> price_func = PriceFuncDict({put_symbol: {timestamp: 4.8}, call_symbol: {timestamp: 3.5}})
    >>> order = MarketOrder(contract=basket, timestamp=timestamp, qty=10, reason_code='TEST')
    >>> sim = SimpleMarketSimulator(price_func=price_func, slippage_pct=0)
    >>> out = sim([order], 0, np.array([timestamp]), {}, {}, SimpleNamespace())
    >>> assert(len(out) == 1)
    >>> assert(math.isclose(out[0].price, -1.3))
    >>> assert(out[0].qty == 10)
    '''
    price_func: PriceFunctionType
    slippage_pct: float
    commission: float
    price_rounding: int
    post_trade_func: Callable[[Trade, StrategyContextType], None] | None

    def __init__(self,
                 price_func: PriceFunctionType,
                 slippage_pct: float = 0.,
                 commission: float = 0,
                 price_rounding: int = 3,
                 post_trade_func: Callable[[Trade, StrategyContextType], None] | None = None) -> None:
        '''
        Args:
            price_func: A function that we use to get the price to execute at
            slippage_pct: Slippage per dollar transacted, applied to the absolute price.
                Meant to simulate the difference between bid/ask mid and execution price
            commission: Fee paid to broker per trade. Passed unchanged to every Trade (it is not scaled by qty)
            price_rounding: Number of decimal places execution prices are rounded to
            post_trade_func: If set, called with each trade and the strategy context right after the trade is created
        '''
        self.price_func = price_func
        self.slippage_pct = slippage_pct
        self.commission = commission
        self.price_rounding = price_rounding
        self.post_trade_func = post_trade_func

    def __call__(self,
                 orders: Sequence[Order],
                 i: int,
                 timestamps: np.ndarray,
                 indicators: dict[str, SimpleNamespace],
                 signals: dict[str, SimpleNamespace],
                 strategy_context: SimpleNamespace) -> list[Trade]:
        '''
        Fill market and limit orders at the current price (plus slippage). Other order types are ignored.
        Orders are skipped if any price they need is NaN, or if a limit order's limit is not satisfied.
        TODO: code for stop orders
        '''
        trades: list[Trade] = []
        timestamp = timestamps[i]
        for order in orders:
            if not isinstance(order, MarketOrder) and not isinstance(order, LimitOrder): continue
            contract = order.contract
            if not contract.is_basket():
                raw_price = self.price_func(contract, timestamps, i, strategy_context)
            else:
                # Basket price is the ratio-weighted sum of its legs; stop at the first leg with no price
                raw_price = 0.
                for (_contract, ratio) in contract.components:
                    raw_price += self.price_func(_contract, timestamps, i, strategy_context) * ratio
                    if np.isnan(raw_price):
                        break
            if np.isnan(raw_price): continue
            # Slippage always works against us: buys pay more, sells receive less. Using abs() keeps
            # this true for baskets (e.g. spreads) whose net price is negative.
            slippage = self.slippage_pct * abs(raw_price)
            if order.qty < 0: slippage = -slippage
            price = round(raw_price + slippage, self.price_rounding)
            if isinstance(order, LimitOrder) and np.isfinite(order.limit_price):
                # A buy limit only fills at or below its limit price; a sell limit only at or above it
                if ((order.qty > 0 and price > order.limit_price)
                        or (order.qty < 0 and price < order.limit_price)):
                    continue
            trade = Trade(order.contract, order, timestamp, order.qty, price, self.commission)
            order.fill()
            trades.append(trade)
            if self.post_trade_func is not None:
                self.post_trade_func(trade, strategy_context)
        return trades
[docs]
@dataclass
class PercentOfEquityTradingRule:
    '''
    A rule that trades a percentage of equity.
    Args:
        reason_code: Reason for entering the order, used for display
        price_func: The function we use to get intraday prices
        equity_percent: Fraction of current equity used to size each order (0.1 = 10%)
        long: Whether we want to go long or short
        allocate_risk: If set, the qty for each contract is divided by the number of contracts in the group.
            Otherwise each contract is allocated the full equity_percent
        limit_increment: If not nan, we add or subtract this number from current market price (if selling or buying respectively)
            and create a limit order. If nan, we create market orders
    If any contract in the group has a NaN price or rounds to an order qty of 0, no orders at all are returned.
    '''
    reason_code: str
    price_func: PriceFunctionType
    equity_percent: float = 0.1  # use 10% of equity by default
    long: bool = True
    allocate_risk: bool = False
    limit_increment: float = math.nan

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 strategy_context: StrategyContextType) -> list[Order]:
        timestamp = timestamps[i]
        contracts = contract_group.get_contracts()
        orders: list[Order] = []
        for contract in contracts:
            entry_price_est = self.price_func(contract, timestamps, i, strategy_context)
            if math.isnan(entry_price_est): return []
            curr_equity = account.equity(timestamp)
            risk_amount = self.equity_percent * curr_equity
            order_qty = risk_amount / entry_price_est
            if self.allocate_risk:
                order_qty /= len(contracts)  # divide up qty equally between all contracts
            if not self.long: order_qty *= -1
            # Round towards zero to a whole number of units
            order_qty = math.floor(order_qty) if order_qty > 0 else math.ceil(order_qty)
            if math.isclose(order_qty, 0.): return []
            if math.isfinite(self.limit_increment):
                # Reuse the price fetched above (no second price_func call) and move the limit away
                # from the market: below it for buys, above it for sells
                limit_price = entry_price_est - np.sign(order_qty) * self.limit_increment
                limit_order = LimitOrder(contract=contract,
                                         timestamp=timestamp,
                                         qty=order_qty,
                                         limit_price=limit_price,
                                         reason_code=self.reason_code)
                orders.append(limit_order)
            else:
                market_order = MarketOrder(contract=contract,
                                           timestamp=timestamp,
                                           qty=order_qty,
                                           reason_code=self.reason_code)
                orders.append(market_order)
        return orders
[docs]
@dataclass
class VWAPEntryRule:
    '''
    A rule that generates VWAP orders
    Args:
        reason_code: Reason for each order. For display purposes
        vwap_minutes: How long the vwap period is. For example, a 5 minute vwap order will execute at 5 minute vwap
            from when it is sent to the market
        price_func: A function that this rule uses to get market price at a given timestamp
        long: Whether we want to go long or short
        percent_of_equity: Order qty is calculated so that if the stop price is reached, we lose this fraction
            of equity. If stop_price_ind is not set, qty is sized so the order's notional value is this fraction
            of equity. Either way it is split equally between the contracts in the group
        stop_price_ind: Name of the indicator holding the stop price. Don't enter if estimated entry price is
            market price <= stop price + min_price_diff_pct * market_price (for long orders) or the opposite for short orders
        min_price_diff_pct: See stop_price_ind
        single_entry_per_day: If set, don't enter if this contract group already traded today
    No orders are returned if an order for the group is still open, or if any contract lacks a price,
    a valid stop, or a non-zero order qty.
    '''
    reason_code: str
    vwap_minutes: int
    price_func: PriceFunctionType
    long: bool
    percent_of_equity: float
    stop_price_ind: str | None
    min_price_diff_pct: float
    single_entry_per_day: bool

    def __init__(self,
                 reason_code: str,
                 vwap_minutes: int,
                 price_func: PriceFunctionType,
                 long: bool = True,
                 percent_of_equity: float = 0.1,
                 stop_price_ind: str | None = None,
                 min_price_diff_pct: float = 0,
                 single_entry_per_day: bool = False) -> None:
        self.reason_code = reason_code
        self.price_func = price_func
        self.long = long
        self.vwap_minutes = vwap_minutes
        self.percent_of_equity = percent_of_equity
        self.stop_price_ind = stop_price_ind
        self.min_price_diff_pct = min_price_diff_pct
        self.single_entry_per_day = single_entry_per_day

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 strategy_context: StrategyContextType) -> list[Order]:
        timestamp = timestamps[i]
        if self.single_entry_per_day:
            date = timestamp.astype('M8[D]')
            trades = account.get_trades_for_date(contract_group.name, date)
            if len(trades): return []
        # Don't send a new entry while an order for this contract group is still working
        for existing_order in current_orders:
            if existing_order.contract.contract_group == contract_group and existing_order.is_open(): return []
        orders: list[Order] = []
        contracts = contract_group.get_contracts()
        for contract in contracts:
            entry_price_est = self.price_func(contract, timestamps, i, strategy_context)
            if math.isnan(entry_price_est): return []
            if self.stop_price_ind:
                _stop_price_ind = getattr(indicator_values, self.stop_price_ind)
                stop_price = _stop_price_ind[i]
                # Without a valid stop we cannot size the order by risk
                if math.isnan(stop_price): return []
                if self.long and (entry_price_est - stop_price) < self.min_price_diff_pct * entry_price_est: return []
                if not self.long and (stop_price - entry_price_est) < self.min_price_diff_pct * entry_price_est: return []
                # Negative for shorts (stop above entry), so the qty below comes out negative
                risk_per_unit = entry_price_est - stop_price
            else:
                stop_price = math.nan
                # No stop: size on notional value, signed by direction
                risk_per_unit = entry_price_est if self.long else -entry_price_est
            if risk_per_unit == 0: return []
            curr_equity = account.equity(timestamp)
            risk_amount = self.percent_of_equity * curr_equity
            order_qty = risk_amount / risk_per_unit
            order_qty /= len(contracts)  # divide up equity percentage equally
            order_qty = math.floor(order_qty) if order_qty > 0 else math.ceil(order_qty)
            if math.isclose(order_qty, 0.): return []
            vwap_end_time = timestamp + np.timedelta64(self.vwap_minutes, 'm')
            vwap_order = VWAPOrder(contract=contract,
                                   timestamp=timestamp,
                                   vwap_stop=stop_price,
                                   vwap_end_time=vwap_end_time,
                                   qty=order_qty,
                                   time_in_force=TimeInForce.GTC,
                                   reason_code=self.reason_code)
            orders.append(vwap_order)
        # Return every order created, not just the last one
        return orders
[docs]
@dataclass
class VWAPCloseRule:
    '''
    Rule to close out every position in a contract group at the vwap price.
    Args:
        reason_code: Reason for each order. For display purposes
        vwap_minutes: How long the vwap period is. For example, a 5 minute vwap order will execute at 5 minute vwap
    Nothing is sent while an order for the contract group is still open.
    '''
    reason_code: str
    vwap_minutes: int

    def __init__(self,
                 vwap_minutes: int,
                 reason_code: str) -> None:
        self.vwap_minutes = vwap_minutes
        self.reason_code = reason_code

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 strategy_context: StrategyContextType) -> list[Order]:
        timestamp = timestamps[i]
        # Leave things alone while an order for this contract group is still working
        if any(working.contract.contract_group == contract_group and working.is_open()
               for working in current_orders):
            return []
        positions = account.positions(contract_group, timestamp)
        # One offsetting GTC VWAP order per position
        return [VWAPOrder(contract=position_contract,  # type: ignore
                          timestamp=timestamp,
                          vwap_end_time=timestamp + np.timedelta64(self.vwap_minutes, 'm'),
                          qty=-position_qty,
                          time_in_force=TimeInForce.GTC,
                          reason_code=self.reason_code)
                for position_contract, position_qty in positions]
[docs]
@dataclass
class VWAPMarketSimulator:
    '''
    A market simulator that simulates buying and selling at VWAP prices.
    This works with VWAP orders and ignores all other order types
    The order executes either:
    a. After the vwap end time defined in the VWAP order, or on the last bar of the day or of the data
       if that comes first
    b. If market price <= vwap stop price defined in the VWAP order for buy orders
    c. If market price >= vwap stop price for sell orders
    In cases b and c only the fraction of the order matching the elapsed part of the vwap period is filled.
    After executing, any unfilled remainder of the order is cancelled.
    '''
    price_indicator: str
    volume_indicator: str
    backup_price_indicator: str | None

    def __init__(self,
                 price_indicator: str,
                 volume_indicator: str,
                 backup_price_indicator: str | None = None) -> None:
        '''
        Args:
            price_indicator: An indicator that contains historical trade price per timestamp
            volume_indicator: An indicator that contains volume per timestamp
            backup_price_indicator: Execution price to use if price or volume is missing
        '''
        self.price_indicator = price_indicator
        self.volume_indicator = volume_indicator
        self.backup_price_indicator = backup_price_indicator

    def __call__(self,
                 orders: Sequence[Order],
                 i: int,
                 timestamps: np.ndarray,
                 indicators: dict[ContractGroup, SimpleNamespace],
                 signals: dict[ContractGroup, SimpleNamespace],
                 strategy_context: SimpleNamespace) -> list[Trade]:
        trades = []
        timestamp = timestamps[i]
        for order in orders:
            if not isinstance(order, VWAPOrder): continue
            cg = order.contract.contract_group
            inds = indicators.get(cg)
            assert_(inds is not None, f'indicators not found for contract group: {cg} {timestamp} {i}')
            # Default to None so a missing indicator reaches the descriptive assertion below
            # instead of raising a bare AttributeError
            price_ind = getattr(inds, self.price_indicator, None)  # type: ignore
            assert_(price_ind is not None, f'indicator: {self.price_indicator} not found for contract group: {cg} {timestamp} {i}')
            volume_ind = getattr(inds, self.volume_indicator, None)  # type: ignore
            assert_(volume_ind is not None, f'indicator: {self.volume_indicator} not found for contract group: {cg} {timestamp} {i}')
            end_order = False
            if math.isfinite(order.vwap_stop) and (
                    (order.qty >= 0 and price_ind[i] <= order.vwap_stop) or (order.qty < 0 and price_ind[i] >= order.vwap_stop)):
                end_order = True
            # Keep working the order until its end time, unless this is the last bar of the data or of the day
            if not end_order and timestamp < order.vwap_end_time and i != len(timestamps) - 1 \
                    and not timestamps[i + 1].astype('M8[D]') > timestamps[i].astype('M8[D]'):
                continue
            mask = (price_ind > 0) & (volume_ind > 0)
            if end_order:
                mask &= (timestamps >= order.timestamp) & (timestamps <= timestamp)
            else:
                mask &= (timestamps >= order.timestamp) & (timestamps <= order.vwap_end_time)
            amt = price_ind[mask] * volume_ind[mask]
            if not len(amt):
                # NOTE(review): with no vwap data, sell orders (qty <= 0) are left unfilled; only buys use the backup price
                if order.qty <= 0: continue
                _logger.info(f'using backup price for {cg} {timestamp} {i} qty: {order.qty} {order}')
                assert_(self.backup_price_indicator is not None,
                        f'backup price indicator not found and no vwap found for: {cg} {timestamp} {i}')
                _backup_price_ind = getattr(inds, self.backup_price_indicator, None)  # type: ignore
                assert_(_backup_price_ind is not None, f'backup price indicator not found for: {cg} {timestamp} {i}')
                vwap = _backup_price_ind[i]
            else:
                vwap = np.sum(amt) / np.sum(volume_ind[mask])
            assert_(vwap >= 0)
            fill_qty = order.qty
            if end_order:
                vwap_duration = order.vwap_end_time - order.timestamp
                if vwap_duration > np.timedelta64(0):
                    fill_fraction = min((timestamp - order.timestamp) / vwap_duration, 1)
                else:
                    # A zero-length vwap period is already complete; avoids dividing by a zero timedelta
                    fill_fraction = 1.
                fill_qty = np.fix(order.qty * fill_fraction)
            order.fill(fill_qty)
            order.cancel()
            trade = Trade(order.contract, order, timestamp, fill_qty, vwap)
            trades.append(trade)
        return trades
# Signature of BracketOrderEntryRule.contract_filter: it is called with the same arguments as a rule
# (contract group, bar index, timestamps, indicator values, signal values, account, current orders,
# strategy context) and returns the names of the contracts to trade.
ContractFilterType = Callable[
    [
        ContractGroup,
        int,
        np.ndarray,
        SimpleNamespace,
        np.ndarray,
        Account,
        Sequence[Order],
        StrategyContextType,
    ],
    list[str],
]
[docs]
@dataclass
class BracketOrderEntryRule:
    '''
    A rule that generates orders with stops
    Args:
        reason_code: Reason for the orders created used for display
        price_func: A function that returns price given a contract and timestamp
        long: whether we want to go long or short
        percent_of_equity: How much to risk per trade as a percentage of current equity.
            Used to calculate order qty so that if we get stopped out, we don't lose
            more than this amount. Of course if price gaps up or down rather than moving smoothly,
            we may lose more.
        min_stop_return: We will not enter if the (negative) stop return is closer to zero than this.
            Otherwise, we get very large trade sizes
        max_position_size: If > 0, the order qty is capped so that its notional value is at most this
            fraction of current equity
        single_entry_per_day: If set, no orders are created if the contract group already traded today
        contract_filter: A function that takes the same arguments as a rule (including the ContractGroup)
            and returns a list of contract names to trade at this timestamp. For example,
            for a strategy that trades 5000 stocks, you may want to construct a single signal and apply it
            to different contracts at different times, rather than create 5000 signals that will call your rule
            5000 times every time the signal is true. Names that are not known contracts are skipped.
        stop_return_func: A function with the PriceFunctionType signature that returns the (negative) return
            from entry price to stop price. If not set, order qty is sized on notional value instead.
    >>> timestamps = np.arange(np.datetime64('2023-01-01'), np.datetime64('2023-01-05'))
    >>> sig_values = np.full(len(timestamps), False)
    >>> aapl_prices = np.array([100.1, 100.2, 100.3, 100.4])
    >>> ibm_prices = np.array([200.1, 200.2, 200.3, 200.4])
    >>> aapl_stops = np.array([-0.5, -0.3, -0.2, -0.01])
    >>> ibm_stops= np.array([-0.5, -0.3, -0.2, -0.15])
    >>> price_dict = {'AAPL': (timestamps, aapl_prices), 'IBM': (timestamps, ibm_prices)}
    >>> stop_dict = {'AAPL': (timestamps, aapl_stops), 'IBM': (timestamps, ibm_stops)}
    >>> price_func = PriceFuncArrayDict(price_dict)
    >>> fr = BracketOrderEntryRule('TEST_ENTRY', price_func, long=False)
    >>> default_cg = ContractGroup.get('DEFAULT')
    >>> default_cg.clear_cache()
    >>> default_cg.add_contract(Contract.get_or_create('AAPL'))
    >>> default_cg.add_contract(Contract.get_or_create('IBM'))
    >>> account = SimpleNamespace()
    >>> account.equity = lambda x: 1e6
    >>> orders = fr(default_cg, 1, timestamps, SimpleNamespace(), sig_values, account, [], SimpleNamespace())
    >>> assert len(orders) == 2 and orders[0].qty == -998 and orders[1].qty == -499
    >>> stop_return_func = PriceFuncArrayDict(stop_dict)
    >>> fr = BracketOrderEntryRule('TEST_ENTRY', price_func, long=True, stop_return_func=stop_return_func, min_stop_return=-0.1)
    >>> orders = fr(default_cg, 2, timestamps, SimpleNamespace(), sig_values, account, [], SimpleNamespace())
    >>> assert len(orders) == 2 and orders[0].qty == 4985 and orders[1].qty == 2496
    >>> orders = fr(default_cg, 3, timestamps, SimpleNamespace(), sig_values, account, [], SimpleNamespace())
    >>> assert len(orders) == 1 and orders[0].qty == 3326
    '''
    reason_code: str
    price_func: PriceFunctionType
    long: bool
    percent_of_equity: float
    min_stop_return: float  # must match the attribute set in __init__, or the dataclass repr/eq break
    max_position_size: float
    single_entry_per_day: bool
    contract_filter: ContractFilterType | None
    stop_return_func: PriceFunctionType | None

    def __init__(self,
                 reason_code: str,
                 price_func: PriceFunctionType,
                 long: bool = True,
                 percent_of_equity: float = 0.1,
                 min_stop_return: float = 0,
                 max_position_size: float = 0,
                 single_entry_per_day: bool = False,
                 contract_filter: ContractFilterType | None = None,
                 stop_return_func: PriceFunctionType | None = None) -> None:
        self.reason_code = reason_code
        self.price_func = price_func
        self.long = long
        self.percent_of_equity = percent_of_equity
        self.stop_return_func = stop_return_func
        self.min_stop_return = min_stop_return
        self.max_position_size = max_position_size
        self.single_entry_per_day = single_entry_per_day
        self.contract_filter = contract_filter

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 strategy_context: StrategyContextType) -> list[Order]:
        timestamp = timestamps[i]
        date = timestamp.astype('M8[D]')
        contracts: list[Contract] = []
        if self.contract_filter is not None:
            names = self.contract_filter(
                contract_group, i, timestamps, indicator_values, signal_values, account, current_orders, strategy_context)
            for name in names:
                _contract = Contract.get(name)
                if _contract is None: continue
                contracts.append(_contract)
        else:
            contracts = contract_group.get_contracts()
        # The single-entry check is per contract group, so it gives the same answer for every contract
        if self.single_entry_per_day and len(contracts):
            trades = account.get_trades_for_date(contract_group.name, date)
            if len(trades): return []
        orders: list[Order] = []
        for contract in contracts:
            entry_price_est = self.price_func(contract, timestamps, i, strategy_context)  # type: ignore
            if math.isnan(entry_price_est): continue
            stop_price = 0.  # with no stop function, qty is sized on the full notional value
            if self.stop_return_func is not None:
                stop_return = self.stop_return_func(contract, timestamps, i, strategy_context)  # type: ignore
                assert_(stop_return < 0, f'stop_return must be negative: {stop_return} {timestamp} {contract.symbol}')
                if stop_return > self.min_stop_return:
                    _logger.info(f'stop return: {stop_return} for {contract.symbol} is closer to entry price: {entry_price_est}'
                                 f' than min stop return: {self.min_stop_return}')
                    continue
                # For shorts the stop sits above the entry price
                if not self.long: stop_return = -stop_return
                stop_price = entry_price_est * (1 + stop_return)
                if ((self.long and entry_price_est <= stop_price) or (not self.long and entry_price_est >= stop_price)):
                    _logger.info(f'entry price estimate: {entry_price_est} exceeds stop price: {stop_price}')
                    continue
            curr_equity = account.equity(timestamp)
            risk_amount = self.percent_of_equity * curr_equity
            order_qty = risk_amount / abs(entry_price_est - stop_price)
            if self.max_position_size > 0:
                max_qty = abs(self.max_position_size * curr_equity / entry_price_est)
                orig_qty = order_qty
                if max_qty < abs(orig_qty):
                    order_qty = max_qty * np.sign(orig_qty)
                    _logger.info(f'max postion size exceeded: {contract} timestamp: {timestamp} max_qty: {max_qty}'
                                 f' orig_qty: {orig_qty} new qty: {order_qty} max_qty: {max_qty} pos size: {self.max_position_size}'
                                 f' curr_equity: {curr_equity} entry_price_est: {entry_price_est}')
            order_qty = math.floor(order_qty)
            if not self.long: order_qty = -order_qty
            if math.isclose(order_qty, 0.): continue
            order = MarketOrder(contract=contract,  # type: ignore
                                timestamp=timestamp,
                                qty=order_qty,
                                reason_code=self.reason_code)
            orders.append(order)
        return orders
[docs]
@dataclass
class ClosePositionExitRule:
    '''
    A rule that closes out every open position in a contract group.
    Args:
        reason_code: the reason for closing out used for display purposes
        price_func: the function this rule uses to get market prices (only consulted for limit orders)
        limit_increment: if not nan, we add or subtract this number from current market price
            (if selling or buying respectively) and create a limit order; otherwise a market order
            is created. Must be nan or >= 0
    '''
    reason_code: str
    price_func: PriceFunctionType
    limit_increment: float

    def __init__(self,
                 reason_code: str,
                 price_func: PriceFunctionType,
                 limit_increment: float = math.nan) -> None:
        assert_(math.isnan(limit_increment) or limit_increment >= 0, f'limit_increment: {limit_increment} cannot be negative')
        self.reason_code = reason_code
        self.price_func = price_func
        self.limit_increment = limit_increment

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 strategy_context: StrategyContextType) -> list[Order]:
        timestamp = timestamps[i]
        use_limit_orders = math.isfinite(self.limit_increment)
        exit_orders: list[Order] = []
        for position_contract, position_qty in account.positions(contract_group, timestamp):
            if not use_limit_orders:
                exit_orders.append(MarketOrder(contract=position_contract, timestamp=timestamp, qty=-position_qty,
                                               reason_code=self.reason_code))
                continue
            market_price = self.price_func(position_contract, timestamps, i, strategy_context)
            # Selling out of a long asks above the market; buying back a short bids below it
            offset = self.limit_increment if position_qty >= 0 else -self.limit_increment
            exit_orders.append(LimitOrder(contract=position_contract, timestamp=timestamp, qty=-position_qty,
                                          limit_price=market_price + offset, reason_code=self.reason_code))
        return exit_orders
[docs]
@dataclass
class StopReturnExitRule:
    '''
    A rule that exits any positions in the contract group whose stop has been hit.
    You should record entry prices in the strategy context (e.g. in the market simulator when you enter
    a position) as context.entry_prices[date][symbol].
    Args:
        reason_code: reason for the exit orders, used for display
        price_func: function used to get the current market price
        stop_return_func: function returning the (negative) return from the entry price at which we stop out.
            For short positions its sign is flipped so the stop is above the entry price
    '''
    reason_code: str
    price_func: PriceFunctionType
    stop_return_func: PriceFunctionType

    def __call__(self,
                 contract_group: ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: SimpleNamespace,
                 signal_values: np.ndarray,
                 account: Account,
                 current_orders: Sequence[Order],
                 context: StrategyContextType) -> list[Order]:
        timestamp = timestamps[i]
        date = timestamp.astype('M8[D]')
        positions = list(account.positions(contract_group, timestamp))
        # Nothing to protect, so don't require entry prices to have been recorded
        if not positions: return []
        # .get so a missing date reaches the descriptive assertion instead of a bare KeyError
        entry_prices: dict[str, float] = context.entry_prices.get(date, {})
        assert_(len(entry_prices) > 0, f'no symbols entered for: {date}')
        orders: list[Order] = []
        for contract, qty in positions:
            symbol = contract.symbol
            stop_ret = self.stop_return_func(contract, timestamps, i, context)
            assert_(stop_ret < 0, f'stop_return must be negative: {stop_ret} {timestamp} {symbol}')
            assert_(symbol in entry_prices, f'no entry price for: {symbol} {date}')
            entry_price = entry_prices[symbol]
            if qty < 0: stop_ret = -stop_ret
            stop_price = entry_price * (1 + stop_ret)
            curr_price = self.price_func(contract, timestamps, i, context)
            if math.isnan(curr_price): continue
            # Not stopped out yet: longs still above their stop, shorts still below theirs
            if (qty > 0 and curr_price > stop_price) or (qty <= 0 and curr_price < stop_price): continue
            order = MarketOrder(contract=contract, timestamp=timestamp, qty=-qty, reason_code=self.reason_code)
            orders.append(order)
        return orders
if __name__ == "__main__":
    # Run the doctest examples embedded in this module's docstrings
    from doctest import ELLIPSIS, NORMALIZE_WHITESPACE, testmod
    testmod(optionflags=NORMALIZE_WHITESPACE | ELLIPSIS)
# $$_end_code