29 #ifndef SYNAPTIC_SAMPLING_REWARDGRADIENT_CONNECTIO 30 #define SYNAPTIC_SAMPLING_REWARDGRADIENT_CONNECTIO 34 #include "connection.h" 35 #include "normal_randomdev.h" 36 #include "spikecounter.h" 38 #include "tracing_node.h" 39 #include "connection_updater.h" 40 #include "connection_data_logger.h" 41 #include "spore_names.h" 64 using CommonSynapseProperties::get_status;
65 using CommonSynapseProperties::set_status;
66 using CommonSynapseProperties::calibrate;
69 void set_status(
const DictionaryDatum& d, nest::ConnectorModel& cm);
70 void calibrate(
const nest::TimeConverter& tc);
84 if (reward_transmitter_ != 0)
85 return reward_transmitter_->get_gid();
95 if (reward_transmitter_ == 0)
96 return nest::CommonSynapseProperties::get_node();
98 return reward_transmitter_;
109 result = std_wiener_ * normal_dev_(nest::kernel().rng_manager.get_rng(thread));
120 if (std_gradient_ > 0)
122 result = std_gradient_ * normal_dev_(nest::kernel().rng_manager.get_rng(thread));
130 double drand(nest::thread thread)
const 132 return nest::kernel().rng_manager.get_rng(thread)->drand();
136 double learning_rate_;
137 double episode_length_;
138 double psp_tau_rise_;
139 double psp_tau_fall_;
142 double gradient_noise_;
145 double max_param_change_;
146 double integration_time_;
147 double direct_gradient_rate_;
148 double parameter_mapping_offset_;
149 double weight_scale_;
150 double weight_update_interval_;
151 double gradient_scale_;
152 double psp_cutoff_amplitude_;
157 bool simulate_retracted_synapses_;
158 bool delete_retracted_synapses_;
163 double resolution_unit_;
164 double reward_gradient_update_;
165 double eligibility_trace_update_;
167 double psp_faciliation_update_;
168 double psp_depression_update_;
169 double psp_scale_factor_;
171 long weight_update_steps_;
176 double std_gradient_;
177 librandom::NormalRandomDev normal_dev_;
351 template<
typename target
identifierT>
372 nest::rport receptor_type,
double t_lastspike,
const CommonPropertiesType& cp)
374 if (!dynamic_cast<TracingNode*> (&t))
376 throw nest::IllegalConnection(
"This synapse only works with nodes exposing their firing" 377 " probability trace (i.e. TracingNode-Subclass)!");
380 ConnTestDummyNode dummy_target;
381 ConnectionBase::check_connection_(dummy_target, s, t, receptor_type);
385 void set_status(
const DictionaryDatum& d, nest::ConnectorModel& cm);
387 void send(nest::Event& e, nest::thread t,
double t_lastspike,
const CommonPropertiesType& cp);
388 void check_synapse_params(
const DictionaryDatum& syn_spec )
const;
390 using ConnectionBase::get_delay_steps;
391 using ConnectionBase::get_delay;
392 using ConnectionBase::get_rport;
393 using ConnectionBase::get_target;
403 void set_weight(
double w)
405 synaptic_parameter_ = w;
411 double get_eligibility_trace()
const 413 return eligibility_trace_;
419 double get_psp()
const 421 return psp_facilitation_ - psp_depression_;
427 double get_weight()
const 435 double get_synaptic_parameter()
const 437 return synaptic_parameter_;
443 double get_reward_gradient()
const 445 return reward_gradient_;
451 bool is_degenerated()
const 453 return (psp_facilitation_ == -1.0);
461 double synaptic_parameter_;
463 double psp_facilitation_;
464 double psp_depression_;
466 double eligibility_trace_;
467 double reward_gradient_;
470 double prior_precision_;
472 nest::index recorder_port_;
476 void update_synapse_state(
long t_to,
480 const CommonPropertiesType& cp);
482 void update_synapic_parameter(nest::thread thread,
const CommonPropertiesType& cp);
483 void update_synapic_weight(
long time_step,
const CommonPropertiesType& cp);
485 class ConnTestDummyNode :
public nest::ConnTestDummyNodeBase
488 using nest::ConnTestDummyNodeBase::handles_test_event;
490 nest::port handles_test_event(nest::SpikeEvent&, nest::rport)
492 return nest::invalid_port_;
495 nest::port handles_test_event(nest::DSSpikeEvent&, nest::rport)
497 return nest::invalid_port_;
514 template <
typename target
identifierT>
518 synaptic_parameter_(0.0),
519 psp_facilitation_(0.0),
520 psp_depression_(0.0),
521 eligibility_trace_(0.0),
522 reward_gradient_(0.0),
524 prior_precision_(1.0),
525 recorder_port_(
nest::invalid_index)
534 template <
typename target
identifierT>
538 weight_(rhs.weight_),
539 synaptic_parameter_(rhs.synaptic_parameter_),
540 psp_facilitation_(rhs.psp_facilitation_),
541 psp_depression_(rhs.psp_depression_),
542 eligibility_trace_(rhs.eligibility_trace_),
543 reward_gradient_(rhs.reward_gradient_),
544 prior_mean_(rhs.prior_mean_),
545 prior_precision_(rhs.prior_precision_),
546 recorder_port_(
nest::invalid_index)
555 template <
typename target
identifierT>
567 template <
typename target
identifierT>
579 template <
typename target
identifierT>
588 assert( not omp_in_parallel() );
594 &SynapticSamplingRewardGradientConnection::get_eligibility_trace);
595 logger_->register_recordable_variable(names::psp_values,
596 &SynapticSamplingRewardGradientConnection::get_psp);
598 logger_->register_recordable_variable(names::weight_values,
599 &SynapticSamplingRewardGradientConnection::get_weight);
601 logger_->register_recordable_variable(names::synaptic_parameter_values,
602 &SynapticSamplingRewardGradientConnection::get_synaptic_parameter);
603 logger_->register_recordable_variable(names::reward_gradient_values,
604 &SynapticSamplingRewardGradientConnection::get_reward_gradient);
618 template <
typename target
identifierT>
628 template <
typename target
identifierT>
631 ConnectionBase::get_status(d);
632 def<double>(d, nest::names::weight, weight_);
633 def<double>(d, names::synaptic_parameter, synaptic_parameter_);
634 def<double>(d, names::eligibility_trace, eligibility_trace_);
635 def<double>(d, names::reward_gradient, reward_gradient_);
636 def<double>(d, names::prior_mean, prior_mean_);
637 def<double>(d, names::prior_precision, prior_precision_);
638 def<long>(d, nest::names::size_of,
sizeof (*this));
640 logger()->get_status(d, recorder_port_);
648 template <
typename target
identifierT>
650 nest::ConnectorModel& cm)
652 ConnectionBase::set_status(d, cm);
653 updateValue<double>(d, nest::names::weight, weight_);
654 updateValue<double>(d, names::synaptic_parameter, synaptic_parameter_);
655 updateValue<double>(d, names::prior_mean, prior_mean_);
656 updateValue<double>(d, names::prior_precision, prior_precision_);
658 logger()->set_status(d, recorder_port_);
677 template <
typename target
identifierT>
683 if (is_degenerated())
689 assert(cp.resolution_unit_ > 0.0);
691 const long s_to = std::floor( e.get_stamp().get_ms() / cp.resolution_unit_ );
692 long s_from = std::floor( t_last_spike / cp.resolution_unit_ );
698 update_synapic_weight(0, cp);
705 target->
get_trace(s_from, cp.bap_trace_id_);
708 cp.reward_transmitter_->
get_trace(s_from, cp.dopa_trace_id_);
710 const double t_last_weight_update =
711 std::floor(t_last_spike / cp.weight_update_interval_) * cp.weight_update_interval_;
712 const long s_last_update = std::floor( t_last_weight_update/cp.resolution_unit_ );
714 for (
long next_weight_step = s_last_update + cp.weight_update_steps_;
715 next_weight_step <= s_to;
716 next_weight_step += cp.weight_update_steps_)
718 update_synapse_state(next_weight_step, s_from, bap_trace, dopa_trace, cp);
719 update_synapic_parameter(thread, cp);
720 update_synapic_weight(next_weight_step, cp);
721 s_from = next_weight_step;
726 update_synapse_state(s_to, s_from, bap_trace, dopa_trace, cp);
730 if (cp.delete_retracted_synapses_ && (weight_==0.0))
736 psp_facilitation_ = -1.0;
737 nest::synindex syn_id = nest::Connection<targetidentifierT>::get_syn_id();
739 e.get_sender_gid(), thread, syn_id );
744 if (e.get_rport() >= 0)
747 psp_facilitation_ += 1.0;
748 psp_depression_ += 1.0;
752 e.set_weight(weight_);
754 e.set_delay(get_delay_steps());
755 e.set_receiver(*get_target(thread));
756 e.set_rport(get_rport());
777 template <
typename target
identifierT>
786 if ((weight_ == 0.0) && not cp.simulate_retracted_synapses_)
792 assert(t_to >= t_last_update);
793 long steps = t_to - t_last_update;
795 const double sc_psp = weight_ * cp.psp_scale_factor_;
796 const bool direct_gradient = cp.direct_gradient_rate_ > 0.0;
797 bool psp_active = (psp_facilitation_ != 0.0);
804 eligibility_trace_ *= cp.eligibility_trace_update_;
807 reward_gradient_ *= cp.reward_gradient_update_;
812 psp_facilitation_ *= cp.psp_faciliation_update_;
813 psp_depression_ *= cp.psp_depression_update_;
815 eligibility_trace_ += sc_psp * (psp_facilitation_ - psp_depression_) * (*bap_trace);
817 if (psp_facilitation_ < cp.psp_cutoff_amplitude_)
819 psp_facilitation_ = 0.0;
820 psp_depression_ = 0.0;
825 reward_gradient_ += (*dopa_trace) * eligibility_trace_;
829 synaptic_parameter_ += (*dopa_trace) * cp.learning_rate_ *
830 cp.direct_gradient_rate_ * eligibility_trace_;
847 template <
typename target
identifierT >
852 const double l_rate = cp.weight_update_interval_ * cp.learning_rate_;
855 const double prior = prior_precision_ * (prior_mean_ - synaptic_parameter_);
859 const double d_lik = std::max(-cp.max_param_change_,
860 std::min(cp.max_param_change_, cp.gradient_scale_ * reward_gradient_));
862 const double d_param = l_rate * (prior + d_lik) + cp.
get_d_wiener(thread);
864 synaptic_parameter_ = std::max(cp.min_param_, std::min(cp.max_param_, synaptic_parameter_ + d_param));
875 template <
typename target
identifierT >
879 const bool synapse_is_active = (weight_ != 0.0) || (time_step==0);
882 if (synaptic_parameter_ >= 0.0)
884 weight_ = cp.weight_scale_ * std::exp(synaptic_parameter_ - cp.parameter_mapping_offset_);
891 if (synapse_is_active && not cp.simulate_retracted_synapses_ && (weight_ == 0.0))
894 psp_facilitation_ = 0.0;
895 psp_depression_ = 0.0;
896 eligibility_trace_ = 0.0;
897 reward_gradient_ = 0.0;
900 logger()->record(time_step*cp.resolution_unit_, *
this, recorder_port_);
double get_gradient_noise(nest::thread thread) const
Definition: synaptic_sampling_rewardgradient_connection.h:117
Reward-based synaptic sampling connection class.
Definition: synaptic_sampling_rewardgradient_connection.h:352
void calibrate(const nest::TimeConverter &tc)
Definition: synaptic_sampling_rewardgradient_connection.cpp:146
double drand(nest::thread thread) const
Definition: synaptic_sampling_rewardgradient_connection.h:130
long get_vt_gid() const
Definition: synaptic_sampling_rewardgradient_connection.h:82
SynapticSamplingRewardGradientConnection()
Definition: synaptic_sampling_rewardgradient_connection.h:515
Base class to all nodes that record traces.
Definition: tracing_node.h:52
void trigger_garbage_collector(nest::index target_gid, nest::index sender_gid, nest::thread target_thread, nest::synindex syn_id)
Definition: connection_updater.cpp:211
Definition: poisson_dbl_exp_neuron.cpp:43
Generic version of data logger for connections.
Definition: connection_data_logger.h:80
Class holding the common properties for all synapses of type SynapticSamplingRewardGradientConnection...
Definition: synaptic_sampling_rewardgradient_connection.h:57
void set_status(const DictionaryDatum &d, nest::ConnectorModel &cm)
Definition: synaptic_sampling_rewardgradient_connection.cpp:115
void check_connection(nest::Node &s, nest::Node &t, nest::rport receptor_type, double t_lastspike, const CommonPropertiesType &cp)
Definition: synaptic_sampling_rewardgradient_connection.h:371
Constant iterator class.
Definition: circular_buffer.h:49
void register_recordable_variable(const Name &name, DataAccessFct data_access_fct)
Definition: connection_data_logger.h:104
void check_synapse_params(const DictionaryDatum &syn_spec) const
Definition: synaptic_sampling_rewardgradient_connection.h:619
SynapticSamplingRewardGradientCommonProperties CommonPropertiesType
Type to use for representing common synapse properties.
Definition: synaptic_sampling_rewardgradient_connection.h:361
SynapticSamplingRewardGradientCommonProperties()
Definition: synaptic_sampling_rewardgradient_connection.cpp:68
void get_status(DictionaryDatum &d) const
Definition: synaptic_sampling_rewardgradient_connection.h:629
nest::Node * get_node()
Definition: synaptic_sampling_rewardgradient_connection.h:93
void check_event(nest::SpikeEvent &)
Definition: synaptic_sampling_rewardgradient_connection.h:75
void send(nest::Event &e, nest::thread t, double t_lastspike, const CommonPropertiesType &cp)
Definition: synaptic_sampling_rewardgradient_connection.h:678
static ConnectionUpdateManager * instance()
Definition: connection_updater.cpp:312
double get_d_wiener(nest::thread thread) const
Definition: synaptic_sampling_rewardgradient_connection.h:104
void get_status(DictionaryDatum &d) const
Definition: synaptic_sampling_rewardgradient_connection.cpp:95
Global namespace holding all classes of the SPORE NEST module.
Definition: circular_buffer.h:31
~SynapticSamplingRewardGradientCommonProperties()
Definition: synaptic_sampling_rewardgradient_connection.cpp:88
nest::Connection< targetidentifierT > ConnectionBase
Shortcut for base class.
Definition: synaptic_sampling_rewardgradient_connection.h:364
~SynapticSamplingRewardGradientConnection()
Definition: synaptic_sampling_rewardgradient_connection.h:556
void set_status(const DictionaryDatum &d, nest::ConnectorModel &cm)
Status setter function.
Definition: synaptic_sampling_rewardgradient_connection.h:649
const_iterator get_trace(nest::delay steps, trace_id id) const
Access the trace of id at time step step.
Definition: tracing_node.h:82