synaptic_sampling_rewardgradient_connection.h
1 /*
2  * This file is part of SPORE.
3  *
4  * Copyright (C) 2016, the SPORE team (see AUTHORS).
5  *
6  * SPORE is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * SPORE is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with SPORE. If not, see <http://www.gnu.org/licenses/>.
18  *
19  * For more information see: https://github.com/IGITUGraz/spore-nest-module
20  *
21  * File: synaptic_sampling_rewardgradient_connection.h
22  * Author: Kappel, Hsieh
23  *
24  * This file is based on stdp_dopa_connection.h which is part of NEST
25  * (Copyright (C) 2004 The NEST Initiative).
26  * See: http://nest-initiative.org/
27  */
28 
29 #ifndef SYNAPTIC_SAMPLING_REWARDGRADIENT_CONNECTIO
30 #define SYNAPTIC_SAMPLING_REWARDGRADIENT_CONNECTIO
31 
32 #include <cmath>
33 #include "nest.h"
34 #include "connection.h"
35 #include "normal_randomdev.h"
36 #include "spikecounter.h"
37 
38 #include "tracing_node.h"
39 #include "connection_updater.h"
40 #include "connection_data_logger.h"
41 #include "spore_names.h"
42 
43 #ifdef _OPENMP
44 #include <omp.h>
45 #endif
46 
47 
48 namespace spore
49 {
50 
57 class SynapticSamplingRewardGradientCommonProperties : public nest::CommonSynapseProperties
58 {
59 public:
60 
63 
64  using CommonSynapseProperties::get_status;
65  using CommonSynapseProperties::set_status;
66  using CommonSynapseProperties::calibrate;
67 
68  void get_status(DictionaryDatum& d) const;
69  void set_status(const DictionaryDatum& d, nest::ConnectorModel& cm);
70  void calibrate(const nest::TimeConverter& tc);
71 
75  void check_event(nest::SpikeEvent&)
76  {
77  }
78 
82  long get_vt_gid() const
83  {
84  if (reward_transmitter_ != 0)
85  return reward_transmitter_->get_gid();
86  else
87  return -1;
88  }
89 
93  nest::Node* get_node()
94  {
95  if (reward_transmitter_ == 0)
96  return nest::CommonSynapseProperties::get_node();
97  else
98  return reward_transmitter_;
99  }
100 
104  double get_d_wiener(nest::thread thread) const
105  {
106  double result = 0.0;
107  if (std_wiener_ > 0)
108  {
109  result = std_wiener_ * normal_dev_(nest::kernel().rng_manager.get_rng(thread));
110  }
111  return result;
112  }
113 
117  double get_gradient_noise(nest::thread thread) const
118  {
119  double result = 0.0;
120  if (std_gradient_ > 0)
121  {
122  result = std_gradient_ * normal_dev_(nest::kernel().rng_manager.get_rng(thread));
123  }
124  return result;
125  }
126 
130  double drand(nest::thread thread) const
131  {
132  return nest::kernel().rng_manager.get_rng(thread)->drand();
133  }
134 
135  // parameters
136  double learning_rate_;
137  double episode_length_;
138  double psp_tau_rise_;
139  double psp_tau_fall_;
140 
141  double temperature_;
142  double gradient_noise_;
143  double max_param_;
144  double min_param_;
145  double max_param_change_;
146  double integration_time_;
147  double direct_gradient_rate_;
148  double parameter_mapping_offset_;
149  double weight_scale_;
150  double weight_update_interval_;
151  double gradient_scale_;
152  double psp_cutoff_amplitude_;
153 
154  long bap_trace_id_;
155  long dopa_trace_id_;
156 
157  bool simulate_retracted_synapses_;
158  bool delete_retracted_synapses_;
159 
160  // state variables
161  TracingNode* reward_transmitter_;
162 
163  double resolution_unit_;
164  double reward_gradient_update_;
165  double eligibility_trace_update_;
166 
167  double psp_faciliation_update_;
168  double psp_depression_update_;
169  double psp_scale_factor_;
170 
171  long weight_update_steps_;
172 
173 private:
174 
175  double std_wiener_;
176  double std_gradient_;
177  librandom::NormalRandomDev normal_dev_;
178 };
179 
// Reward-based synaptic sampling connection class.
// Each synapse keeps a synaptic parameter from which its weight is derived.
// It also keeps PSP facilitation/depression traces, an eligibility trace and
// a reward gradient; their dynamics are implemented in the out-of-class
// member definitions further down in this file.
351 template<typename targetidentifierT>
352 class SynapticSamplingRewardGradientConnection : public nest::Connection<targetidentifierT>
353 {
354 public:
355
// NOTE(review): several lines of the original header are missing from this
// listing (see the gaps in the line numbers). According to the Doxygen
// cross-reference at the end of the page, header line 361 holds the
// CommonPropertiesType typedef (SynapticSamplingRewardGradientCommonProperties)
// used below. The declarations of the constructors, the destructor and the
// static logger()/logger_ members are not visible here either, although all of
// them are used or defined further down. Confirm against the original header.
359
362
// Shortcut for the NEST base connection class.
364  typedef nest::Connection<targetidentifierT> ConnectionBase;
365
// Check that a connection s -> t may be established.
// Throws nest::IllegalConnection if the target is not a TracingNode.
// Otherwise runs the base-class check against ConnTestDummyNode (see below).
371  void check_connection(nest::Node& s, nest::Node& t,
372  nest::rport receptor_type, double t_lastspike, const CommonPropertiesType& cp)
373  {
374  if (!dynamic_cast<TracingNode*> (&t))
375  {
376  throw nest::IllegalConnection("This synapse only works with nodes exposing their firing"
377  " probability trace (i.e. TracingNode-Subclass)!");
378  }
379
380  ConnTestDummyNode dummy_target;
381  ConnectionBase::check_connection_(dummy_target, s, t, receptor_type);
382  }
383
// Status access; see the out-of-class definitions below.
384  void get_status(DictionaryDatum& d) const;
385  void set_status(const DictionaryDatum& d, nest::ConnectorModel& cm);
386
// Advance the synapse state up to the event time and deliver the event.
387  void send(nest::Event& e, nest::thread t, double t_lastspike, const CommonPropertiesType& cp);
// Validate a syn_spec dictionary (currently a no-op, see definition below).
388  void check_synapse_params( const DictionaryDatum& syn_spec ) const;
389
390  using ConnectionBase::get_delay_steps;
391  using ConnectionBase::get_delay;
392  using ConnectionBase::get_rport;
393  using ConnectionBase::get_target;
394
// Sets synaptic_parameter_, NOT weight_.
// weight_ is recomputed from the parameter in update_synapic_weight().
403  void set_weight(double w)
404  {
405  synaptic_parameter_ = w;
406  }
407
// Current eligibility trace (registered as a recordable variable in logger()).
411  double get_eligibility_trace() const
412  {
413  return eligibility_trace_;
414  }
415
// Current PSP: difference of the facilitation and depression traces.
419  double get_psp() const
420  {
421  return psp_facilitation_ - psp_depression_;
422  }
423
// Current synaptic weight.
427  double get_weight() const
428  {
429  return weight_;
430  }
431
// Current synaptic parameter.
435  double get_synaptic_parameter() const
436  {
437  return synaptic_parameter_;
438  }
439
// Current reward gradient.
443  double get_reward_gradient() const
444  {
445  return reward_gradient_;
446  }
447
// True if the synapse was marked for deletion.
// send() sets psp_facilitation_ to the sentinel -1.0 when it hands the
// synapse to the garbage collector.
451  bool is_degenerated() const
452  {
453  return (psp_facilitation_ == -1.0);
454  }
455
457
458 private:
459
// Weight derived from synaptic_parameter_, and the parameter itself.
460  double weight_;
461  double synaptic_parameter_;
462
// Presynaptic PSP traces; a facilitation value of -1.0 marks a degenerated synapse.
463  double psp_facilitation_;
464  double psp_depression_;
465
466  double eligibility_trace_;
467  double reward_gradient_;
468
// Prior over the synaptic parameter, used in update_synapic_parameter().
469  double prior_mean_;
470  double prior_precision_;
471
// Recording slot of this synapse in the data logger; nest::invalid_index until assigned.
472  nest::index recorder_port_;
473
475
// State integration and parameter/weight updates (defined out of class below).
476  void update_synapse_state(long t_to,
477  long t_last_update,
478  TracingNode::const_iterator& bap_trace,
479  TracingNode::const_iterator& dopa_trace,
480  const CommonPropertiesType& cp);
481
482  void update_synapic_parameter(nest::thread thread, const CommonPropertiesType& cp);
483  void update_synapic_weight(long time_step, const CommonPropertiesType& cp);
484
// Dummy node used by check_connection(); it returns nest::invalid_port_
// for SpikeEvent and DSSpikeEvent test events.
485  class ConnTestDummyNode : public nest::ConnTestDummyNodeBase
486  {
487  public:
488  using nest::ConnTestDummyNodeBase::handles_test_event;
489
490  nest::port handles_test_event(nest::SpikeEvent&, nest::rport)
491  {
492  return nest::invalid_port_;
493  }
494
495  nest::port handles_test_event(nest::DSSpikeEvent&, nest::rport)
496  {
497  return nest::invalid_port_;
498  }
499  };
500 };
501 
502 
503 //
504 // Implementation of SynapticSamplingRewardGradientConnection
505 //
506 
507 //
508 // Object life cycle
509 //
510 
514 template <typename targetidentifierT>
516 : ConnectionBase(),
517 weight_(0.0),
518 synaptic_parameter_(0.0),
519 psp_facilitation_(0.0),
520 psp_depression_(0.0),
521 eligibility_trace_(0.0),
522 reward_gradient_(0.0),
523 prior_mean_(0.0),
524 prior_precision_(1.0),
525 recorder_port_(nest::invalid_index)
526 {
527  // make sure the global logger object is instantiated here.
528  logger();
529 }
530 
534 template <typename targetidentifierT>
537 : ConnectionBase(rhs),
538 weight_(rhs.weight_),
539 synaptic_parameter_(rhs.synaptic_parameter_),
540 psp_facilitation_(rhs.psp_facilitation_),
541 psp_depression_(rhs.psp_depression_),
542 eligibility_trace_(rhs.eligibility_trace_),
543 reward_gradient_(rhs.reward_gradient_),
544 prior_mean_(rhs.prior_mean_),
545 prior_precision_(rhs.prior_precision_),
546 recorder_port_(nest::invalid_index)
547 {
548  // make sure the global logger object is instantiated here.
549  logger();
550 }
551 
555 template <typename targetidentifierT>
557 {
558 }
559 
560 //
561 // Instance of global data logger singleton.
562 //
563 
567 template <typename targetidentifierT>
// NOTE(review): header lines 568-569 are missing from this listing. They should
// hold the out-of-class definition of the global logger_ pointer that logger()
// below initialises lazily. Until they are restored, the template line above has
// no declaration to apply to.
570
// Returns the global data logger (logger_) shared by all synapses of this type.
// On first use (logger_ still null) the recordable variables are registered
// together with their getters: eligibility trace, psp, weight, synaptic
// parameter and reward gradient.
// NOTE(review): header lines 580-581 (return type and qualified name of logger())
// and 591 (presumably the allocation of logger_) are missing from this listing.
579 template <typename targetidentifierT>
582 {
583  if (!logger_)
584  {
585 #ifdef _OPENMP
586  // Setting up the logger is not thread-safe. This assertion protects the map that is
587  // built inside the data logger when the synapse is instantiated for the first time.
588  assert( not omp_in_parallel() );
589 #endif
590
592
593  logger_->register_recordable_variable(names::eligibility_trace_values,
594  &SynapticSamplingRewardGradientConnection::get_eligibility_trace);
595  logger_->register_recordable_variable(names::psp_values,
596  &SynapticSamplingRewardGradientConnection::get_psp);
597
598  logger_->register_recordable_variable(names::weight_values,
599  &SynapticSamplingRewardGradientConnection::get_weight);
600
601  logger_->register_recordable_variable(names::synaptic_parameter_values,
602  &SynapticSamplingRewardGradientConnection::get_synaptic_parameter);
603  logger_->register_recordable_variable(names::reward_gradient_values,
604  &SynapticSamplingRewardGradientConnection::get_reward_gradient);
605  }
606
607  return logger_;
608 }
609 
610 //
611 // Parameter and state extractions and manipulation functions
612 //
613 
618 template <typename targetidentifierT>
620  const
621 {
622  // FIXME!! Check synaptic parameters here!
623 }
624 
628 template <typename targetidentifierT>
630 {
631  ConnectionBase::get_status(d);
632  def<double>(d, nest::names::weight, weight_);
633  def<double>(d, names::synaptic_parameter, synaptic_parameter_);
634  def<double>(d, names::eligibility_trace, eligibility_trace_);
635  def<double>(d, names::reward_gradient, reward_gradient_);
636  def<double>(d, names::prior_mean, prior_mean_);
637  def<double>(d, names::prior_precision, prior_precision_);
638  def<long>(d, nest::names::size_of, sizeof (*this));
639 
640  logger()->get_status(d, recorder_port_);
641 }
642 
648 template <typename targetidentifierT>
650  nest::ConnectorModel& cm)
651 {
652  ConnectionBase::set_status(d, cm);
653  updateValue<double>(d, nest::names::weight, weight_);
654  updateValue<double>(d, names::synaptic_parameter, synaptic_parameter_);
655  updateValue<double>(d, names::prior_mean, prior_mean_);
656  updateValue<double>(d, names::prior_precision, prior_precision_);
657 
658  logger()->set_status(d, recorder_port_);
659 }
660 
661 //
662 // Synapse event handling
663 //
664 
677 template <typename targetidentifierT>
679  nest::thread thread,
680  double t_last_spike,
681  const CommonPropertiesType& cp)
682 {
683  if (is_degenerated())
684  {
685  // synapse is waiting for the garbage collector.
686  return;
687  }
688 
689  assert(cp.resolution_unit_ > 0.0);
690 
691  const long s_to = std::floor( e.get_stamp().get_ms() / cp.resolution_unit_ );
692  long s_from = std::floor( t_last_spike / cp.resolution_unit_ );
693 
694  if (s_to > s_from)
695  {
696  if (s_from == 0)
697  {
698  update_synapic_weight(0, cp);
699  }
700  // prepare the pointer to the target neuron. We can safely static_cast
701  // since the connection is checked when established.
702  TracingNode* target = static_cast<TracingNode*> (get_target(thread));
703 
704  TracingNode::const_iterator bap_trace =
705  target->get_trace(s_from, cp.bap_trace_id_);
706 
707  TracingNode::const_iterator dopa_trace =
708  cp.reward_transmitter_->get_trace(s_from, cp.dopa_trace_id_);
709 
710  const double t_last_weight_update =
711  std::floor(t_last_spike / cp.weight_update_interval_) * cp.weight_update_interval_;
712  const long s_last_update = std::floor( t_last_weight_update/cp.resolution_unit_ );
713 
714  for (long next_weight_step = s_last_update + cp.weight_update_steps_;
715  next_weight_step <= s_to;
716  next_weight_step += cp.weight_update_steps_)
717  {
718  update_synapse_state(next_weight_step, s_from, bap_trace, dopa_trace, cp);
719  update_synapic_parameter(thread, cp);
720  update_synapic_weight(next_weight_step, cp);
721  s_from = next_weight_step;
722  }
723 
724  if (s_to > s_from)
725  {
726  update_synapse_state(s_to, s_from, bap_trace, dopa_trace, cp);
727  }
728  }
729 
730  if (cp.delete_retracted_synapses_ && (weight_==0.0))
731  {
732  // synapse prepares to be picked up by the garbage collector.
733  // invalid value of -1.0 for psp_facilitation_ is used to indicate
734  // synapses to be deleted. The synapse will be removed next time
735  // when the garbage collector is invoked.
736  psp_facilitation_ = -1.0;
737  nest::synindex syn_id = nest::Connection<targetidentifierT>::get_syn_id();
738  ConnectionUpdateManager::instance()->trigger_garbage_collector(get_target(thread)->get_gid(),
739  e.get_sender_gid(), thread, syn_id );
740  return;
741  }
742 
743  // Make sure that the event is not a SynapseUpdateEvent.
744  if (e.get_rport() >= 0)
745  {
746  // Apply presynaptic spike
747  psp_facilitation_ += 1.0;
748  psp_depression_ += 1.0;
749 
750  if (weight_ > 0.0)
751  {
752  e.set_weight(weight_);
753 
754  e.set_delay(get_delay_steps());
755  e.set_receiver(*get_target(thread));
756  e.set_rport(get_rport());
757  e();
758  }
759  }
760 }
761 
777 template <typename targetidentifierT>
779  long t_last_update,
781  bap_trace,
783  dopa_trace,
784  const CommonPropertiesType& cp)
785 {
786  if ((weight_ == 0.0) && not cp.simulate_retracted_synapses_)
787  {
788  // synapse is retracted. psps and eligibility traces are not going to be simulated.
789  return;
790  }
791 
792  assert(t_to >= t_last_update);
793  long steps = t_to - t_last_update;
794 
795  const double sc_psp = weight_ * cp.psp_scale_factor_;
796  const bool direct_gradient = cp.direct_gradient_rate_ > 0.0;
797  bool psp_active = (psp_facilitation_ != 0.0);
798 
799  while( steps )
800  {
801  // This loop will - considering every call - iterate through EVERY time step (in steps of resolution)
802 
803  // decay eligibility trace
804  eligibility_trace_ *= cp.eligibility_trace_update_;
805 
806  // decay gradient variable
807  reward_gradient_ *= cp.reward_gradient_update_;
808 
809  // update postsynaptic spike potential
810  if (psp_active)
811  {
812  psp_facilitation_ *= cp.psp_faciliation_update_;
813  psp_depression_ *= cp.psp_depression_update_;
814 
815  eligibility_trace_ += sc_psp * (psp_facilitation_ - psp_depression_) * (*bap_trace);
816 
817  if (psp_facilitation_ < cp.psp_cutoff_amplitude_)
818  {
819  psp_facilitation_ = 0.0;
820  psp_depression_ = 0.0;
821  psp_active = false;
822  }
823  }
824 
825  reward_gradient_ += (*dopa_trace) * eligibility_trace_;
826 
827  if (direct_gradient)
828  {
829  synaptic_parameter_ += (*dopa_trace) * cp.learning_rate_ *
830  cp.direct_gradient_rate_ * eligibility_trace_;
831  }
832 
833  ++bap_trace;
834  ++dopa_trace;
835  --steps;
836  }
837 }
838 
847 template < typename targetidentifierT >
849 update_synapic_parameter(nest::thread thread, const CommonPropertiesType& cp)
850 {
851  // update synaptic parameters
852  const double l_rate = cp.weight_update_interval_ * cp.learning_rate_;
853 
854  // compute prior
855  const double prior = prior_precision_ * (prior_mean_ - synaptic_parameter_);
856 
857  reward_gradient_ += cp.get_gradient_noise(thread);
858 
859  const double d_lik = std::max(-cp.max_param_change_,
860  std::min(cp.max_param_change_, cp.gradient_scale_ * reward_gradient_));
861 
862  const double d_param = l_rate * (prior + d_lik) + cp.get_d_wiener(thread);
863 
864  synaptic_parameter_ = std::max(cp.min_param_, std::min(cp.max_param_, synaptic_parameter_ + d_param));
865 }
866 
875 template < typename targetidentifierT >
877 update_synapic_weight(long time_step, const CommonPropertiesType& cp)
878 {
879  const bool synapse_is_active = (weight_ != 0.0) || (time_step==0);
880 
881  // update synaptic weight
882  if (synaptic_parameter_ >= 0.0)
883  {
884  weight_ = cp.weight_scale_ * std::exp(synaptic_parameter_ - cp.parameter_mapping_offset_);
885  }
886  else
887  {
888  weight_ = 0.0;
889  }
890 
891  if (synapse_is_active && not cp.simulate_retracted_synapses_ && (weight_ == 0.0))
892  {
893  // synapse is entering the retracted state, eligibility trace is reset.
894  psp_facilitation_ = 0.0;
895  psp_depression_ = 0.0;
896  eligibility_trace_ = 0.0;
897  reward_gradient_ = 0.0;
898  }
899 
900  logger()->record(time_step*cp.resolution_unit_, *this, recorder_port_);
901 }
902 
903 }
904 
905 #endif
double get_gradient_noise(nest::thread thread) const
Definition: synaptic_sampling_rewardgradient_connection.h:117
Reward-based synaptic sampling connection class.
Definition: synaptic_sampling_rewardgradient_connection.h:352
void calibrate(const nest::TimeConverter &tc)
Definition: synaptic_sampling_rewardgradient_connection.cpp:146
double drand(nest::thread thread) const
Definition: synaptic_sampling_rewardgradient_connection.h:130
long get_vt_gid() const
Definition: synaptic_sampling_rewardgradient_connection.h:82
SynapticSamplingRewardGradientConnection()
Definition: synaptic_sampling_rewardgradient_connection.h:515
Base class to all nodes that record traces.
Definition: tracing_node.h:52
void trigger_garbage_collector(nest::index target_gid, nest::index sender_gid, nest::thread target_thread, nest::synindex syn_id)
Definition: connection_updater.cpp:211
Definition: poisson_dbl_exp_neuron.cpp:43
Generic version of data logger for connections.
Definition: connection_data_logger.h:80
Class holding the common properties for all synapses of type SynapticSamplingRewardGradientConnection...
Definition: synaptic_sampling_rewardgradient_connection.h:57
void set_status(const DictionaryDatum &d, nest::ConnectorModel &cm)
Definition: synaptic_sampling_rewardgradient_connection.cpp:115
void check_connection(nest::Node &s, nest::Node &t, nest::rport receptor_type, double t_lastspike, const CommonPropertiesType &cp)
Definition: synaptic_sampling_rewardgradient_connection.h:371
Constant iterator class.
Definition: circular_buffer.h:49
void register_recordable_variable(const Name &name, DataAccessFct data_access_fct)
Definition: connection_data_logger.h:104
void check_synapse_params(const DictionaryDatum &syn_spec) const
Definition: synaptic_sampling_rewardgradient_connection.h:619
SynapticSamplingRewardGradientCommonProperties CommonPropertiesType
Type to use for representing common synapse properties.
Definition: synaptic_sampling_rewardgradient_connection.h:361
SynapticSamplingRewardGradientCommonProperties()
Definition: synaptic_sampling_rewardgradient_connection.cpp:68
void get_status(DictionaryDatum &d) const
Definition: synaptic_sampling_rewardgradient_connection.h:629
nest::Node * get_node()
Definition: synaptic_sampling_rewardgradient_connection.h:93
void check_event(nest::SpikeEvent &)
Definition: synaptic_sampling_rewardgradient_connection.h:75
void send(nest::Event &e, nest::thread t, double t_lastspike, const CommonPropertiesType &cp)
Definition: synaptic_sampling_rewardgradient_connection.h:678
static ConnectionUpdateManager * instance()
Definition: connection_updater.cpp:312
double get_d_wiener(nest::thread thread) const
Definition: synaptic_sampling_rewardgradient_connection.h:104
void get_status(DictionaryDatum &d) const
Definition: synaptic_sampling_rewardgradient_connection.cpp:95
Global namespace holding all classes of the SPORE NEST module.
Definition: circular_buffer.h:31
~SynapticSamplingRewardGradientCommonProperties()
Definition: synaptic_sampling_rewardgradient_connection.cpp:88
nest::Connection< targetidentifierT > ConnectionBase
Shortcut for base class.
Definition: synaptic_sampling_rewardgradient_connection.h:364
~SynapticSamplingRewardGradientConnection()
Definition: synaptic_sampling_rewardgradient_connection.h:556
void set_status(const DictionaryDatum &d, nest::ConnectorModel &cm)
Status setter function.
Definition: synaptic_sampling_rewardgradient_connection.h:649
const_iterator get_trace(nest::delay steps, trace_id id) const
Access the trace of id at time step step.
Definition: tracing_node.h:82