SSAGES  0.1
A MetaDynamics Package
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Groups Pages
StringMethod.cpp
1 
23 #include "ElasticBand.h"
24 #include "FiniteTempString.h"
25 #include "StringMethod.h"
26 #include "Swarm.h"
27 #include "CVs/CVManager.h"
28 #include "Validator/ObjectRequirement.h"
29 #include "Drivers/DriverException.h"
30 #include "Snapshot.h"
31 #include "spline.h"
32 #include "schema.h"
33 
34 using namespace Json;
35 
36 namespace SSAGES
37 {
38  void StringMethod::PrintString(const CVList& CV)
39  {
40  if(comm_.rank() == 0)
41  {
42  //Write node, iteration, centers of the string and current CV value to output file
43  stringout_.precision(8);
44  stringout_ << mpiid_ << " " << iteration_ << " ";
45 
46  for(size_t i = 0; i < centers_.size(); i++)
47  stringout_ << worldstring_[mpiid_][i] << " " << CV[i]->GetValue() << " ";
48 
49  stringout_ << std::endl;
50  }
51  }
52 
53  void StringMethod::GatherNeighbors(std::vector<double> *lcv0, std::vector<double> *ucv0)
54  {
55  MPI_Status status;
56 
57  if(comm_.rank() == 0)
58  {
59  MPI_Sendrecv(&centers_[0], centers_.size(), MPI_DOUBLE, sendneigh_, 1234,
60  &(*lcv0)[0], centers_.size(), MPI_DOUBLE, recneigh_, 1234,
61  world_, &status);
62 
63  MPI_Sendrecv(&centers_[0], centers_.size(), MPI_DOUBLE, recneigh_, 4321,
64  &(*ucv0)[0], centers_.size(), MPI_DOUBLE, sendneigh_, 4321,
65  world_, &status);
66  }
67 
68  MPI_Bcast(&(*lcv0)[0],centers_.size(),MPI_DOUBLE,0,comm_);
69  MPI_Bcast(&(*ucv0)[0],centers_.size(),MPI_DOUBLE,0,comm_);
70  }
71 
72  void StringMethod::StringReparam(double alpha_star)
73  {
74  std::vector<double> alpha_star_vector(numnodes_,0.0);
75 
76  //Reparameterization
77  //Alpha star is the uneven mesh, approximated as linear distance between points
78  if(comm_.rank()==0)
79  alpha_star_vector[mpiid_] = mpiid_ == 0 ? 0 : alpha_star;
80 
81  //Gather each alpha_star into a vector
82  MPI_Allreduce(MPI_IN_PLACE, &alpha_star_vector[0], numnodes_, MPI_DOUBLE, MPI_SUM, world_);
83 
84  for(size_t i = 1; i < alpha_star_vector.size(); i++)
85  alpha_star_vector[i] += alpha_star_vector[i-1];
86 
87  for(size_t i = 1; i < alpha_star_vector.size(); i++)
88  alpha_star_vector[i] /= alpha_star_vector[numnodes_ - 1];
89 
90  tk::spline spl; //Cubic spline interpolation
91 
92  for(size_t i = 0; i < centers_.size(); i++)
93  {
94  std::vector<double> cvs_new(numnodes_, 0.0);
95 
96  if(comm_.rank() == 0)
97  cvs_new[mpiid_] = centers_[i];
98 
99  MPI_Allreduce(MPI_IN_PLACE, &cvs_new[0], numnodes_, MPI_DOUBLE, MPI_SUM, world_);
100 
101  spl.set_points(alpha_star_vector, cvs_new);
102  centers_[i] = spl(mpiid_/(numnodes_ - 1.0));
103  }
104  }
105 
106  void StringMethod::UpdateWorldString(const CVList& cvs)
107  {
108  for(size_t i = 0; i < centers_.size(); i++)
109  {
110  std::vector<double> cvs_new(numnodes_, 0.0);
111 
112  if(comm_.rank() == 0)
113  {
114  cvs_new[mpiid_] = centers_[i];
115  }
116 
117  MPI_Allreduce(MPI_IN_PLACE, &cvs_new[0], numnodes_, MPI_DOUBLE, MPI_SUM, world_);
118 
119  for(int j = 0; j < numnodes_; j++)
120  {
121  worldstring_[j][i] = cvs_new[j];
122  //Represent worldstring in periodic space
123  worldstring_[j][i] = cvs[i]->GetPeriodicValue(worldstring_[j][i]);
124  }
125  }
126  }
127 
128  bool StringMethod::CheckEnd(const CVList& CV)
129  {
130  if(maxiterator_ && iteration_ > maxiterator_)
131  {
132  std::cout << "System has reached max string method iterations (" << maxiterator_ << ") as specified in the input file(s)." << std::endl;
133  std::cout << "Exiting now" << std::endl;
134  PrintString(CV); //Ensure that the system prints out if it's about to exit
135  MPI_Abort(world_, EXIT_FAILURE);
136  }
137 
138  int local_tolvalue = TolCheck();
139 
140  MPI_Allreduce(MPI_IN_PLACE, &local_tolvalue, 1, MPI_INT, MPI_LAND, world_);
141 
142  if(local_tolvalue)
143  {
144  std::cout << "System has converged within tolerance criteria. Exiting now" << std::endl;
145  PrintString(CV); //Ensure that the system prints out if it's about to exit
146  MPI_Abort(world_, EXIT_FAILURE);
147  }
148 
149  return true;
150  }
151 
152  void StringMethod::PreSimulation(Snapshot* snapshot, const CVManager& cvmanager)
153  {
154  char file[1024];
155  mpiid_ = snapshot->GetWalkerID();
156  sprintf(file, "node-%04d.log",mpiid_);
157  stringout_.open(file);
158 
159  auto cvs = cvmanager.GetCVs(cvmask_);
160  SetSendRecvNeighbors();
161  worldstring_.resize(numnodes_);
162  for(auto& w : worldstring_)
163  w.resize(centers_.size());
164 
165  UpdateWorldString(cvs);
166  PrintString(cvs);
167  }
168 
169  void StringMethod::SetSendRecvNeighbors()
170  {
171  std::vector<int> wiids(world_.size(), 0);
172 
173  //Set the neighbors
174  recneigh_ = -1;
175  sendneigh_ = -1;
176 
177  MPI_Allgather(&mpiid_, 1, MPI_INT, &wiids[0], 1, MPI_INT, world_);
178  numnodes_ = int(*std::max_element(wiids.begin(), wiids.end())) + 1;
179 
180  // Ugly for now...
181  for(size_t i = 0; i < wiids.size(); i++)
182  {
183  if(mpiid_ == 0)
184  {
185  sendneigh_ = comm_.size();
186  if(wiids[i] == numnodes_ - 1)
187  {
188  recneigh_ = i;
189  break;
190  }
191  }
192  else if (mpiid_ == numnodes_ - 1)
193  {
194  sendneigh_ = 0;
195  if(wiids[i] == mpiid_ - 1)
196  {
197  recneigh_ = i;
198  break;
199  }
200  }
201  else
202  {
203  if(wiids[i] == mpiid_ + 1)
204  {
205  sendneigh_ = i;
206  break;
207  }
208  if(wiids[i] == mpiid_ - 1 && recneigh_ == -1)
209  recneigh_ = i;
210  }
211  }
212  }
213 
215  StringMethod* StringMethod::Build(const Value& json,
216  const MPI_Comm& world,
217  const MPI_Comm& comm,
218  const std::string& path)
219  {
220  ObjectRequirement validator;
221  Value schema;
222  Reader reader;
223 
224  StringMethod* m = nullptr;
225 
226  reader.parse(JsonSchema::StringMethod, schema);
227  validator.Parse(schema, path);
228 
229  // Validate inputs.
230  validator.Validate(json, path);
231  if(validator.HasErrors())
232  throw BuildException(validator.GetErrors());
233 
234  unsigned int wid = mxx::comm(world).rank()/mxx::comm(comm).size();
235  std::vector<double> centers;
236  for(auto& s : json["centers"][wid])
237  centers.push_back(s.asDouble());
238 
239  auto maxiterator = json.get("max_iterations", 0).asInt();
240 
241  std::vector<double> ksprings;
242  for(auto& s : json["ksprings"])
243  ksprings.push_back(s.asDouble());
244 
245  auto freq = json.get("frequency", 1).asInt();
246 
247  // Get stringmethod flavor.
248  std::string flavor = json.get("flavor", "none").asString();
249  if(flavor == "ElasticBand")
250  {
251  reader.parse(JsonSchema::ElasticBandMethod, schema);
252  validator.Parse(schema, path);
253 
254  // Validate inputs.
255  validator.Validate(json, path);
256  if(validator.HasErrors())
257  throw BuildException(validator.GetErrors());
258 
259  auto eqsteps = json.get("equilibration_steps", 20).asInt();
260  auto evsteps = json.get("evolution_steps", 5).asInt();
261  auto stringspring = json.get("kstring", 10.0).asDouble();
262  auto isteps = json.get("block_iterations", 100).asInt();
263  auto tau = json.get("time_step", 0.1).asDouble();
264 
265  m = new ElasticBand(world, comm, centers,
266  maxiterator, isteps,
267  tau, ksprings, eqsteps,
268  evsteps, stringspring, freq);
269 
270  if(json.isMember("tolerance"))
271  {
272  std::vector<double> tol;
273  for(auto& s : json["tolerance"])
274  tol.push_back(s.asDouble());
275 
276  m->SetTolerance(tol);
277  }
278  }
279  else if(flavor == "FTS")
280  {
281  reader.parse(JsonSchema::FTSMethod, schema);
282  validator.Parse(schema, path);
283 
284  // Validate inputs.
285  validator.Validate(json, path);
286  if(validator.HasErrors())
287  throw BuildException(validator.GetErrors());
288 
289  auto isteps = json.get("block_iterations", 2000).asInt();
290  auto tau = json.get("time_step", 0.1).asDouble();
291  auto kappa = json.get("kappa", 0.1).asDouble();
292  auto springiter = json.get("umbrella_iterations",2000).asDouble();
293  m = new FiniteTempString(world, comm, centers,
294  maxiterator, isteps,
295  tau, ksprings, kappa,
296  springiter, freq);
297 
298  if(json.isMember("tolerance"))
299  {
300  std::vector<double> tol;
301  for(auto& s : json["tolerance"])
302  tol.push_back(s.asDouble());
303 
304  m->SetTolerance(tol);
305  }
306  }
307  else if(flavor == "SWARM")
308  {
309  reader.parse(JsonSchema::SwarmMethod, schema);
310  validator.Parse(schema, path);
311 
312  //Validate input
313  validator.Validate(json, path);
314  if(validator.HasErrors())
315  throw BuildException(validator.GetErrors());
316 
317  auto InitialSteps = json.get("initial_steps", 2500).asInt();
318  auto HarvestLength = json.get("harvest_length", 10).asInt();
319  auto NumberTrajectories = json.get("number_of_trajectories", 250).asInt();
320  auto SwarmLength = json.get("swarm_length", 20).asInt();
321 
322  m = new Swarm(world, comm, centers, maxiterator, ksprings, freq, InitialSteps, HarvestLength, NumberTrajectories, SwarmLength);
323 
324  if(json.isMember("tolerance"))
325  {
326  std::vector<double> tol;
327  for(auto& s : json["tolerance"])
328  tol.push_back(s.asDouble());
329 
330  m->SetTolerance(tol);
331  }
332  }
333 
334  return m;
335  }
336 
337 }
unsigned GetWalkerID() const
Get walker ID.
Definition: Snapshot.h:193
bool HasErrors()
Check if errors have occurred.
Definition: Requirement.h:86
Collective variable manager.
Definition: CVManager.h:40
Finite Temperature String Method.
std::vector< CollectiveVariable * > CVList
List of Collective Variables.
Definition: types.h:51
String base class for FTS, Swarm, and elastic band.
Definition: StringMethod.h:38
Class containing a snapshot of the current simulation in time.
Definition: Snapshot.h:43
Swarm of Trajectories String Method.
Definition: Swarm.h:32
virtual void Parse(Value json, const std::string &path) override
Parse JSON value to generate Requirement(s).
Exception to be thrown when building the Driver fails.
std::vector< std::string > GetErrors()
Get list of error messages.
Definition: Requirement.h:92
std::vector< CollectiveVariable * > GetCVs(const std::vector< uint > &mask=std::vector< uint >()) const
Get CV iterator.
Definition: CVManager.h:80
void SetTolerance(std::vector< double > tol)
Set the tolerance for quitting method.
Definition: StringMethod.h:202
Requirements on an object.
Multi-walker Elastic Band.
Definition: ElasticBand.h:34
virtual void Validate(const Value &json, const std::string &path) override
Validate JSON value.